← Back to Home

NN VOCABULARY

Complete neural network reference guide from psychomonkeyinc/nn_vocab

The Complete Neural Network Vocabulary

Your Comprehensive Guide to Modern Neural Architecture

A living encyclopedia of every neural network component, architecture, optimizer, loss function, and theoretical framework. From foundational building blocks to cutting-edge research. This is the reference that connects the dots between what networks do, why they work, and how to use them.

Whether you're debugging a Transformer at 3 AM, designing a new architecture, or just trying to understand what "grouped-query attention with rotary embeddings" actually means, this guide has your back. Written for practitioners who build real systems, not just read papers.


Visual Architecture Examples

To understand how neural networks actually work, let's start with visual examples of complete architectures. These show the full structure from input to output.

Example 1: O3 Deep Reasoning Model with Scratchpad

The O3 model uses chain-of-thought reasoning with an explicit scratchpad for intermediate computations.

┌─────────────────────────────────────────────────────────────────┐
│                    O3 DEEP REASONING MODEL                      │
└─────────────────────────────────────────────────────────────────┘

INPUT: "What is 47 × 23?"

┌──────────────────────────────────────────────────────────────────┐
│  ENCODER (Bidirectional Transformer)                             │
│  ┌────────┐  ┌────────┐  ┌────────┐  ┌────────┐               │
│  │ Token  │→│ Multi  │→│ FFN    │→│ Layer  │  (×12 layers)   │
│  │ Embed  │  │ Head   │  │        │  │ Norm   │               │
│  └────────┘  │ Attn   │  └────────┘  └────────┘               │
│              └────────┘                                          │
└──────────────────────────────────────────────────────────────────┘

┌──────────────────────────────────────────────────────────────────┐
│  REASONING SCRATCHPAD (Autoregressive Decoder)                   │
│  ┌───────────────────────────────────────────────────────────┐  │
│  │ Step 1: "Break down: 47 × 23 = 47 × 20 + 47 × 3"         │  │
│  │ ↓ [Causal Attention] ↓                                     │  │
│  │ Step 2: "Calculate 47 × 20 = 940"                         │  │
│  │ ↓ [Cross-Attention to Encoder] ↓                          │  │
│  │ Step 3: "Calculate 47 × 3 = 141"                          │  │
│  │ ↓ [Self-Attention on scratchpad] ↓                        │  │
│  │ Step 4: "Sum: 940 + 141 = 1081"                           │  │
│  └───────────────────────────────────────────────────────────┘  │
│                                                                   │
│  Each step uses:                                                 │
│  • Causal masking (can't see future steps)                      │
│  • Cross-attention to input encoding                            │
│  • FFN for computation                                          │
└──────────────────────────────────────────────────────────────────┘

┌──────────────────────────────────────────────────────────────────┐
│  OUTPUT HEAD                                                      │
│  ┌──────────────────────────────────────────────┐               │
│  │ Final Answer: "1081"                         │               │
│  │ Confidence: 0.98                             │               │
│  │ Reasoning Chain: [4 steps shown above]      │               │
│  └──────────────────────────────────────────────┘               │
└──────────────────────────────────────────────────────────────────┘

Key Components:
• Encoder processes input bidirectionally (BERT-style)
• Scratchpad generates reasoning steps autoregressively (GPT-style)
• Cross-attention connects scratchpad to encoded input
• Each step conditions on previous steps via causal masking
• Final answer extracted from last scratchpad token

Why this architecture works:

  • Explicit reasoning steps make model interpretable
  • Chain-of-thought improves complex problem solving
  • Scratchpad acts as working memory for multi-step reasoning

Example 2: Tree of Thought with Parent's Attention

The Tree of Thought model with Parent's Attention (eye-in-back-of-head style) explores multiple reasoning paths while maintaining backward awareness for error correction.

┌─────────────────────────────────────────────────────────────────┐
│         TREE OF THOUGHT WITH PARENT'S ATTENTION                  │
└─────────────────────────────────────────────────────────────────┘

INPUT: "If 3x + 7 = 22, what is x?"

┌──────────────────────────────────────────────────────────────────┐
│  ENCODER (Context Understanding)                                 │
│  ┌────────┐  ┌────────┐  ┌────────┐                            │
│  │ Token  │→│ Multi  │→│ FFN    │  (×6 layers)                │
│  │ Embed  │  │ Head   │  │        │                             │
│  └────────┘  │ Attn   │  └────────┘                            │
│              └────────┘                                          │
└──────────────────────────────────────────────────────────────────┘

┌──────────────────────────────────────────────────────────────────┐
│  TREE EXPANSION WITH PARENT'S ATTENTION                          │
│                                                                   │
│  Root: "Solve 3x + 7 = 22"                                       │
│         ↓                                                         │
│    ┌────┴────┬────────┬────────┐                                │
│    ↓         ↓        ↓        ↓                                 │
│  Path1:   Path2:   Path3:   Path4:                              │
│  "Sub 7   "Div by  "Trial   "Factor                             │
│   both    3 first  & error" approach"                           │
│   sides"  (wrong)"           (complex)                           │
│    ↓         ↓        ↓        ↓                                 │
│  Step:    Step:    Step:    Step:                               │
│  "3x=15"  "x+7/3=" "Try x=5" "..."                              │
│    ↓         ↑ ⚠️     ↓        ↓                                 │
│    ↓      PARENT'S   ↓        ↓                                 │
│    ↓      ATTENTION  ↓        ↓                                 │
│    ↓      TRIGGERED! ↓        ↓                                 │
│    ↓         ↓        ↓        ↓                                 │
│                                                                   │
│  ┌──────────────────────────────────────────────────────────┐  │
│  │ PARENT'S ATTENTION MECHANISM (Eye-in-Back-of-Head)       │  │
│  │                                                            │  │
│  │  Forward Computation: ────────────────────────→           │  │
│  │                                                            │  │
│  │  Backward Monitoring: ←─────────────── (vigilant)        │  │
│  │                        ╱                                   │  │
│  │  Anomaly Score: ∑ᵢ |grad(stepᵢ) - E[grad]| > θ           │  │
│  │                                              ↓             │  │
│  │  IF anomaly detected:                    STOP!            │  │
│  │    1. Halt forward pass                                   │  │
│  │    2. Add error to scratchpad                            │  │
│  │    3. Backtrack to last valid state                      │  │
│  │    4. Prune invalid branch                               │  │
│  │    5. Resume from checkpoint                             │  │
│  │                                                            │  │
│  │  Random Backward Pass (p=0.1):                           │  │
│  │    • Validates even "correct-looking" paths              │  │
│  │    • Creates training data from mistakes                 │  │
│  │    • Learns error recovery patterns                      │  │
│  └──────────────────────────────────────────────────────────┘  │
│                                                                   │
│  Path2 PRUNED: ❌ "Dividing by 3 first is wrong order"          │
│  ↓ (Added to scratchpad for learning)                           │
│                                                                   │
│  Path1 continues: "3x = 15" → "x = 5" ✓                         │
│  Path3 continues: "Try x=5: 3(5)+7 = 22 ✓" → "x = 5" ✓          │
│  Path4 pruned: Too complex, not converging                      │
│                                                                   │
└──────────────────────────────────────────────────────────────────┘

┌──────────────────────────────────────────────────────────────────┐
│  SCRATCHPAD (Learning from Mistakes)                             │
│  ┌───────────────────────────────────────────────────────────┐  │
│  │ Mistake Log:                                               │  │
│  │ • Path2: "Cannot divide by 3 before isolating 3x term"   │  │
│  │   Reason: "Order of operations violated"                  │  │
│  │   Correction: "Must subtract first, then divide"          │  │
│  │                                                            │  │
│  │ Successful Paths:                                          │  │
│  │ • Path1 (algebraic): 3x+7=22 → 3x=15 → x=5               │  │
│  │ • Path3 (verification): Trial x=5 confirms answer        │  │
│  │                                                            │  │
│  │ Training Data Generated:                                   │  │
│  │ Input: "3x + 7 = 22"                                      │  │
│  │ Wrong Path: [Path2 steps]                                 │  │
│  │ Error Signal: -0.8 (strong negative)                      │  │
│  │ Correct Path: [Path1 steps]                               │  │
│  │ Reward: +1.0                                              │  │
│  └───────────────────────────────────────────────────────────┘  │
└──────────────────────────────────────────────────────────────────┘

┌──────────────────────────────────────────────────────────────────┐
│  OUTPUT WITH CONFIDENCE                                           │
│  ┌──────────────────────────────────────────────────────────┐   │
│  │ Answer: x = 5                                             │   │
│  │ Confidence: 0.95 (verified by 2 independent paths)       │   │
│  │ Reasoning:                                                 │   │
│  │   Path1 (algebraic): 3x + 7 = 22 → 3x = 15 → x = 5      │   │
│  │   Path3 (verification): 3(5) + 7 = 15 + 7 = 22 ✓        │   │
│  │ Rejected Paths: 1 (order of operations error)            │   │
│  │ Self-Correction: 1 error caught and corrected            │   │
│  └──────────────────────────────────────────────────────────┘   │
└──────────────────────────────────────────────────────────────────┘

Mathematical Formulation of Parent's Attention:

python
# Forward pass computation
for step_t in reasoning_path:
    h_t = forward_step(h_{t-1}, context)  # Normal forward computation
    
    # Parent's Attention: Monitor backward gradient signal
    backward_score = compute_backward_anomaly(h_t, h_{t-1}, expected_grad)
    
    # Anomaly detection
    if backward_score > threshold:
        # STOP! Something's wrong
        scratchpad.append({
            'step': step_t,
            'error': backward_score,
            'context': get_context(h_{t-k:t})
        })
        # Backtrack and prune branch
        prune_branch(step_t)
        return checkpoint_state
    
    # Random backward verification (p=0.1)
    if random() < p_verify:
        validation_score = verify_path_correctness(h_{0:t})
        if validation_score < valid_threshold:
            # Caught a subtle error - add to training data
            training_data.append({
                'input': original_input,
                'wrong_path': h_{0:t},
                'error_type': classify_error(h_{0:t}),
                'correction_needed': True
            })
            # Continue but flag this path as suspicious
            path_confidence *= 0.8

# Parent's Attention Score Computation
def compute_backward_anomaly(h_t, h_{t-1}, expected_grad):
    """
    Computes how much current step deviates from expected reasoning
    by looking at gradient signals from backward direction
    """
    # Backward attention weights
    α_backward = softmax(Q_backward @ K_past^T / √d_k)
    
    # Expected vs actual gradient magnitude
    grad_actual = ||∇h_t||_2
    grad_expected = E[||∇h_{t-k:t-1}||_2]
    
    # Deviation score
    deviation = |grad_actual - grad_expected| / grad_expected
    
    # Logical consistency check (backward reasoning)
    consistency = logical_validator(h_{t-1} → h_t)
    
    # Combined anomaly score
    return (deviation + (1 - consistency)) / 2

PyTorch Implementation Example:

python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ParentsAttention(nn.Module):
    """
    Eye-in-back-of-head attention mechanism.
    Monitors backward context while computing forward.
    """
    def __init__(self, d_model=512, n_heads=8, threshold=0.5, p_verify=0.1):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.threshold = threshold
        self.p_verify = p_verify
        
        # Forward attention (standard)
        self.forward_attn = nn.MultiheadAttention(d_model, n_heads)
        
        # Backward monitoring attention
        self.backward_attn = nn.MultiheadAttention(d_model, n_heads)
        
        # Anomaly detection network
        self.anomaly_detector = nn.Sequential(
            nn.Linear(d_model * 2, d_model),
            nn.ReLU(),
            nn.Linear(d_model, 1),
            nn.Sigmoid()
        )
        
        # Gradient expectation (running average)
        self.register_buffer('expected_grad_norm', torch.tensor(1.0))
        self.register_buffer('grad_history', torch.zeros(100))
        self.register_buffer('grad_idx', torch.tensor(0, dtype=torch.long))
        
    def forward(self, query, key, value, past_hidden_states):
        """
        Args:
            query: Current step query [seq_len, batch, d_model]
            key, value: Context for attention
            past_hidden_states: Previous steps [history_len, batch, d_model]
        
        Returns:
            output: Forward attention output
            should_stop: Boolean indicating if anomaly detected
            anomaly_score: Deviation score
        """
        batch_size = query.size(1)
        
        # Standard forward attention
        forward_out, _ = self.forward_attn(query, key, value)
        
        # Parent's Attention: Look backward
        if past_hidden_states is not None and past_hidden_states.size(0) > 0:
            # Attend to past hidden states
            backward_out, backward_weights = self.backward_attn(
                query, past_hidden_states, past_hidden_states
            )
            
            # Compute anomaly score
            combined = torch.cat([forward_out, backward_out], dim=-1)
            anomaly_score = self.anomaly_detector(combined).squeeze(-1)
            
            # Check activation magnitude deviation (proxy for gradient issues)
            # Note: Actual gradient checks happen during backward pass hooks
            if self.training:
                activation_norm = torch.norm(forward_out, dim=-1).mean()
                activation_deviation = abs(activation_norm - self.expected_grad_norm) / (self.expected_grad_norm + 1e-8)
                
                # Update running average
                self.grad_history[self.grad_idx % 100] = activation_norm.detach()
                self.grad_idx.add_(1)
                self.expected_grad_norm = self.grad_history.mean()
                
                # Combine with attention-based anomaly
                total_anomaly = (anomaly_score + activation_deviation) / 2
            else:
                total_anomaly = anomaly_score
            
            # Random backward verification (with deterministic option)
            should_verify = (torch.rand(1, device=query.device).item() < self.p_verify) if self.training else False
            if should_verify:
                # Additional verification pass
                verification_score = self._verify_reasoning_path(
                    past_hidden_states, forward_out
                )
                # Ensure same shape for max operation
                total_anomaly = torch.maximum(total_anomaly, verification_score.expand_as(total_anomaly))
            
            # Determine if we should stop and backtrack
            should_stop = (total_anomaly > self.threshold).any()
            
            return forward_out, should_stop, total_anomaly.mean()
        
        # No history yet, just return forward
        return forward_out, False, torch.tensor(0.0, device=query.device)
    
    def _verify_reasoning_path(self, past_states, current_state):
        """
        Randomly verify if the reasoning path makes sense.
        This creates training data from potential mistakes.
        """
        # Simplified consistency check
        # In practice, this would be a learned verification network
        past_mean = past_states.mean(dim=0)  # [batch, d_model]
        current_mean = current_state.mean(dim=0)  # [batch, d_model]
        
        # Ensure shapes match completely for cosine similarity
        if past_mean.shape != current_mean.shape:
            # Handle any shape mismatch by taking minimum dimensions
            # This is a fallback; ideally shapes should always match
            if past_mean.dim() != current_mean.dim():
                # Dimension mismatch - return high error signal
                return torch.tensor(1.0, device=current_state.device).unsqueeze(0)
            
            # Match batch and feature dimensions
            min_batch = min(past_mean.size(0), current_mean.size(0))
            min_features = min(past_mean.size(-1), current_mean.size(-1))
            past_mean = past_mean[:min_batch, ..., :min_features]
            current_mean = current_mean[:min_batch, ..., :min_features]
        
        # Check if current state is consistent with past trajectory
        consistency = F.cosine_similarity(
            past_mean, current_mean, dim=-1
        ).abs()
        
        # Low consistency = potential error
        # Return as tensor matching input device
        return (1.0 - consistency).mean().unsqueeze(0)


class TreeOfThoughtWithParentsAttention(nn.Module):
    """
    Complete Tree of Thought model with Parent's Attention.
    """
    def __init__(self, vocab_size, d_model=512, n_heads=8, n_layers=6, max_branches=4):
        super().__init__()
        self.d_model = d_model
        self.max_branches = max_branches
        
        # Token embedding
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = nn.Parameter(torch.randn(512, d_model))
        
        # Encoder layers
        self.encoder_layers = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model, n_heads)
            for _ in range(n_layers)
        ])
        
        # Tree expansion with Parent's Attention
        self.tree_layers = nn.ModuleList([
            ParentsAttention(d_model, n_heads)
            for _ in range(n_layers)
        ])
        
        # Branch scoring
        self.branch_scorer = nn.Linear(d_model, 1)
        
    def forward(self, x, explore_branches=True):
        """
        Args:
            x: Input tokens [batch, seq_len]
            explore_branches: Whether to explore multiple reasoning paths
        
        Returns:
            output: Best reasoning path
            scratchpad: Collected mistakes and corrections
        """
        batch_size = x.size(0)
        
        # Initialize scratchpad for this forward pass
        # Note: This implementation is not thread-safe. For concurrent inference,
        # use model.eval() mode and separate model instances per thread
        scratchpad = []
        
        # Encode input
        h = self.embedding(x) + self.pos_encoding[:x.size(1)]
        for layer in self.encoder_layers:
            h = layer(h)
        
        if explore_branches:
            # Tree of Thought: Explore multiple paths
            paths = self._explore_tree(h, depth=0, scratchpad=scratchpad)
            
            # Select best path
            best_path = max(paths, key=lambda p: p['score'])
            
            return best_path['output'], scratchpad
        else:
            # Single path reasoning
            output = self._single_path_reasoning(h)
            return output, scratchpad
    
    def _explore_tree(self, h, depth, scratchpad, max_depth=5):
        """Explore multiple reasoning paths with Parent's Attention"""
        if depth >= max_depth:
            return [{'output': h, 'score': 0.0, 'path': []}]
        
        paths = []
        past_states = None
        
        # Generate multiple branch hypotheses
        for branch_id in range(self.max_branches):
            branch_h = h.clone()
            branch_path = []
            branch_valid = True
            
            # Process this branch with Parent's Attention
            for layer in self.tree_layers:
                # Self-attention: Q, K, V all from current branch state
                # This allows the model to refine its current reasoning
                step_out, should_stop, anomaly = layer(
                    branch_h.transpose(0, 1),  # Query: current state
                    branch_h.transpose(0, 1),  # Key: current state
                    branch_h.transpose(0, 1),  # Value: current state
                    past_states                # Past states for Parent's Attention
                )
                
                branch_h = step_out.transpose(0, 1)
                
                # Parent's Attention detected anomaly!
                if should_stop:
                    # Log to scratchpad
                    scratchpad.append({
                        'branch': branch_id,
                        'depth': depth,
                        'anomaly_score': anomaly.item(),
                        'action': 'pruned',
                        'reason': 'backward_anomaly_detected'
                    })
                    branch_valid = False
                    break
                
                branch_path.append(branch_h)
                # Stack states along time dimension, ensure consistent shapes
                if branch_path:
                    # Filter states with matching shape and ensure we have at least one
                    matching_states = [s for s in branch_path if s.shape == branch_h.shape]
                    past_states = torch.stack(matching_states) if matching_states else None
            
            if branch_valid:
                # Score this branch
                score = self.branch_scorer(branch_h.mean(dim=1)).mean()
                paths.append({
                    'output': branch_h,
                    'score': score.item(),
                    'path': branch_path
                })
        
        return paths if paths else [{'output': h, 'score': -1.0, 'path': []}]

Key Components:Forward Active, Backward Vigilant: Computes forward while monitoring backward gradients • Anomaly Detection: Stops when gradient deviation or logical inconsistency detected • Scratchpad Learning: Records mistakes for training data generation • Random Verification: Probabilistically validates correct-looking paths (p=0.1) • Branch Pruning: Eliminates invalid reasoning paths early • Error Recovery: Backtracks to last valid checkpoint and resumes

Why this architecture works:

  • Parent's Attention catches errors before they propagate
  • Tree exploration finds multiple solution paths
  • Random backward passes create diverse training data
  • Learning from mistakes improves error correction
  • Scratchpad accumulates meta-learning patterns
  • Combines tree search with self-correction

Use cases:

  • Mathematical reasoning with error checking
  • Code generation with syntax validation
  • Planning tasks with constraint verification
  • Any domain where intermediate mistakes are common and costly

Training strategy:

python
# Training loop with mistake injection
for batch in dataloader:
    output, scratchpad = model(batch, explore_branches=True)
    
    # Loss on correct outputs
    loss_correct = criterion(output, target)
    
    # Contrastive loss on mistakes (learn to avoid them)
    loss_mistakes = 0
    for mistake in scratchpad:
        if mistake['action'] == 'pruned':
            # Penalize the model for even considering this path
            loss_mistakes += mistake['anomaly_score']
    
    # Combined loss
    loss = loss_correct + 0.1 * loss_mistakes
    loss.backward()
    optimizer.step()

Performance characteristics:

  • Computational cost: 2-3× standard transformer (backward monitoring overhead)
  • Memory: O(depth × branches) for tree exploration
  • Accuracy: +15-25% on mathematical reasoning tasks
  • Self-correction rate: Catches ~80% of errors before final answer

Example 3: Tree of Thought (Standard)

The standard Tree of Thought explores multiple reasoning paths in parallel without backward monitoring, selecting the best path through scoring.

┌─────────────────────────────────────────────────────────────────┐
│              TREE OF THOUGHT (STANDARD)                          │
└─────────────────────────────────────────────────────────────────┘

INPUT: "Find the path from A to D with minimum cost"

┌──────────────────────────────────────────────────────────────────┐
│  ENCODER (Problem Understanding)                                 │
│  ┌────────┐  ┌────────┐  ┌────────┐                            │
│  │ Token  │→│ Multi  │→│ FFN    │  (×6 layers)                │
│  │ Embed  │  │ Head   │  │        │                             │
│  └────────┘  │ Attn   │  └────────┘                            │
│              └────────┘                                          │
└──────────────────────────────────────────────────────────────────┘

┌──────────────────────────────────────────────────────────────────┐
│  THOUGHT TREE EXPANSION                                          │
│                                                                   │
│  Root: "Start at A, find path to D"                             │
│         ↓                                                         │
│    ┌────┴────┬────────┬────────┐                                │
│    ↓         ↓        ↓        ↓                                 │
│  Path1:   Path2:   Path3:   Path4:                              │
│  A→B→D    A→C→D    A→B→C→D  A→D                                 │
│  (cost:8) (cost:7) (cost:9)  (cost:10)                         │
│    ↓         ↓        ↓        ↓                                 │
│                                                                   │
│  Each path explores independently:                               │
│  ┌─────────────────────────────────────────────────────────┐   │
│  │ Path Generation:                                         │   │
│  │                                                           │   │
│  │ For each branch:                                         │   │
│  │   1. Generate next possible steps                       │   │
│  │   2. Evaluate intermediate states                       │   │
│  │   3. Continue exploration until goal or depth limit     │   │
│  │   4. Calculate path score/cost                          │   │
│  │                                                           │   │
│  │ No pruning during exploration - all paths complete      │   │
│  └─────────────────────────────────────────────────────────┘   │
│                                                                   │
│  All paths explored to completion:                               │
│  Path1: A→B→D       Score: 0.85  Cost: 8  ✓                     │
│  Path2: A→C→D       Score: 0.92  Cost: 7  ✓✓ (BEST)             │
│  Path3: A→B→C→D     Score: 0.75  Cost: 9  ✓                     │
│  Path4: A→D         Score: 0.60  Cost: 10 ✓                     │
│                                                                   │
└──────────────────────────────────────────────────────────────────┘

┌──────────────────────────────────────────────────────────────────┐
│  PATH SCORING & SELECTION                                         │
│  ┌───────────────────────────────────────────────────────────┐  │
│  │ Score each complete path:                                  │  │
│  │                                                             │  │
│  │ Score = w₁·validity + w₂·efficiency + w₃·confidence       │  │
│  │                                                             │  │
│  │ Path1: 0.8·1.0 + 0.1·0.8 + 0.1·0.9 = 0.85                 │  │
│  │ Path2: 0.8·1.0 + 0.1·1.0 + 0.1·0.8 = 0.92 ← Best         │  │
│  │ Path3: 0.8·1.0 + 0.1·0.6 + 0.1·0.9 = 0.75                 │  │
│  │ Path4: 0.8·0.7 + 0.1·0.5 + 0.1·1.0 = 0.60                 │  │
│  │                                                             │  │
│  │ Select path with highest score                             │  │
│  └───────────────────────────────────────────────────────────┘  │
└──────────────────────────────────────────────────────────────────┘

┌──────────────────────────────────────────────────────────────────┐
│  OUTPUT                                                           │
│  ┌──────────────────────────────────────────────────────────┐   │
│  │ Best Path: A → C → D                                      │   │
│  │ Cost: 7                                                    │   │
│  │ Confidence: 0.92                                           │   │
│  │ Reasoning:                                                 │   │
│  │   - Explored 4 possible paths                             │   │
│  │   - A→C→D has lowest cost (7)                             │   │
│  │   - All intermediate steps validated                      │   │
│  │ Alternative paths available: [Path1, Path3, Path4]       │   │
│  └──────────────────────────────────────────────────────────┘   │
└──────────────────────────────────────────────────────────────────┘

PyTorch Implementation:

python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TreeOfThought(nn.Module):
    """
    Standard Tree of Thought: explores multiple reasoning paths
    and selects the best one based on scoring.
    """
    def __init__(self, vocab_size, d_model=512, n_heads=8, n_layers=6, max_branches=4):
        super().__init__()
        self.d_model = d_model
        self.max_branches = max_branches
        
        # Token embedding
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = nn.Parameter(torch.randn(512, d_model))
        
        # Encoder layers
        self.encoder_layers = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model, n_heads)
            for _ in range(n_layers)
        ])
        
        # Decoder for path generation
        self.decoder_layers = nn.ModuleList([
            nn.TransformerDecoderLayer(d_model, n_heads)
            for _ in range(n_layers)
        ])
        
        # Path scoring network
        self.path_scorer = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.ReLU(),
            nn.Linear(d_model // 2, 1),
            nn.Sigmoid()
        )
        
        # Output head
        self.output_head = nn.Linear(d_model, vocab_size)
        
    def forward(self, x, max_depth=5):
        """
        Args:
            x: Input tokens [batch, seq_len]
            max_depth: Maximum depth for tree exploration
        
        Returns:
            best_path: Best reasoning path
            all_paths: All explored paths with scores
        """
        batch_size = x.size(0)
        
        # Encode input
        h = self.embedding(x) + self.pos_encoding[:x.size(1)]
        for layer in self.encoder_layers:
            h = layer(h)
        
        # Explore multiple paths
        all_paths = []
        for branch_id in range(self.max_branches):
            path = self._explore_single_path(h, branch_id, max_depth)
            all_paths.append(path)
        
        # Score all paths
        scored_paths = []
        for path in all_paths:
            score = self._score_path(path)
            scored_paths.append({
                'sequence': path['sequence'],
                'hidden_states': path['hidden_states'],
                'score': score,
                'branch_id': path['branch_id']
            })
        
        # Select best path
        best_path = max(scored_paths, key=lambda p: p['score'])
        
        return best_path, scored_paths
    
    def _explore_single_path(self, encoded_input, branch_id, max_depth):
        """
        Explore a single reasoning path to completion.
        No pruning - always explores to max depth or termination.
        """
        current_state = encoded_input
        sequence = []
        hidden_states = []
        
        for depth in range(max_depth):
            # Generate next step
            for layer in self.decoder_layers:
                current_state = layer(
                    current_state,
                    encoded_input
                )
            
            # Store state
            hidden_states.append(current_state)
            
            # Generate token
            logits = self.output_head(current_state[:, -1, :])
            next_token = torch.argmax(logits, dim=-1)
            sequence.append(next_token)
            
            # Check for termination token (e.g., EOS)
            if (next_token == 0).all():  # Assuming 0 is EOS
                break
        
        return {
            'sequence': torch.stack(sequence, dim=1) if sequence else None,
            'hidden_states': hidden_states,
            'branch_id': branch_id
        }
    
    def _score_path(self, path):
        """
        Score a complete path based on final hidden state.
        """
        if not path['hidden_states']:
            return 0.0
        
        # Use final hidden state for scoring
        final_state = path['hidden_states'][-1]
        # Average across sequence and batch
        score = self.path_scorer(final_state.mean(dim=1)).mean()
        
        return score.item()


# Usage example
model = TreeOfThought(vocab_size=50000, d_model=512, n_heads=8)
input_tokens = torch.randint(0, 50000, (1, 20))  # [batch=1, seq_len=20]

best_path, all_paths = model(input_tokens, max_depth=5)

print(f"Best path score: {best_path['score']:.3f}")
print(f"Number of paths explored: {len(all_paths)}")
for i, path in enumerate(all_paths):
    print(f"  Path {i}: score={path['score']:.3f}")

Training Strategy:

python
# Train to maximize score of correct paths, minimize score of incorrect paths
for batch in dataloader:
    input_tokens, correct_output = batch
    
    # Explore tree
    best_path, all_paths = model(input_tokens)
    
    # Compute loss on best path
    loss_output = criterion(best_path['sequence'], correct_output)
    
    # Diversity loss: encourage exploration of different paths
    loss_diversity = 0
    for i, path_i in enumerate(all_paths):
        for j, path_j in enumerate(all_paths[i+1:], i+1):
            # Penalize similar paths
            similarity = F.cosine_similarity(
                path_i['hidden_states'][-1].mean(dim=1),
                path_j['hidden_states'][-1].mean(dim=1),
                dim=-1
            ).mean()
            loss_diversity -= similarity  # Negative = encourage diversity
    
    loss = loss_output + 0.1 * loss_diversity
    loss.backward()
    optimizer.step()

Key Differences from Parent's Attention variant:

  • No early pruning: All paths explored to completion
  • Post-hoc selection: Best path chosen after all exploration
  • Simpler: No backward monitoring or anomaly detection
  • More compute: Explores all branches fully, even suboptimal ones
  • Good for: Problems where all paths provide useful signal

Why this architecture works:

  • Parallel exploration covers multiple reasoning strategies
  • Post-hoc scoring avoids premature path elimination
  • Diversity encourages complementary reasoning approaches
  • Simpler than self-correcting variants but still effective

Use cases:

  • Problem solving where multiple valid approaches exist
  • Creative tasks requiring diverse solutions
  • Search problems with well-defined scoring functions
  • Scenarios where exploration cost is acceptable

Performance characteristics:

  • Computational cost: max_branches × max_depth × model_forward
  • Memory: O(branches × depth) for all paths
  • Accuracy: Good for problems with clear scoring metrics
  • Exploration: Exhaustive within branch/depth limits

Example 4: Chain of Thought with Parent's Attention

Parent's Attention applied to linear Chain of Thought reasoning adds backward error monitoring to sequential reasoning without tree exploration.

┌─────────────────────────────────────────────────────────────────┐
│       CHAIN OF THOUGHT WITH PARENT'S ATTENTION                   │
└─────────────────────────────────────────────────────────────────┘

INPUT: "Calculate: (15 + 23) × 2 - 10"

┌──────────────────────────────────────────────────────────────────┐
│  ENCODER (Context Understanding)                                 │
│  ┌────────┐  ┌────────┐                                         │
│  │ Token  │→│ Multi  │  (×6 layers)                            │
│  │ Embed  │  │ Head   │                                         │
│  └────────┘  │ Attn   │                                         │
│              └────────┘                                          │
└──────────────────────────────────────────────────────────────────┘

┌──────────────────────────────────────────────────────────────────┐
│  CHAIN OF THOUGHT REASONING (with Parent's Attention)            │
│                                                                   │
│  Step 1: "First, add 15 + 23"                                    │
│          ↓ (forward reasoning)                                   │
│          Result: 38  ← Parent's Attention: ✓ Looks good         │
│          ↓                                                        │
│  Step 2: "Now multiply 38 × 2"                                   │
│          ↓ (forward reasoning)                                   │
│          Result: 76  ← Parent's Attention: ✓ Consistent         │
│          ↓                                                        │
│  Step 3: "Finally subtract 10"                                   │
│          ↓ (forward reasoning)                                   │
│          Result: 66  ← Parent's Attention: ✓ Valid              │
│                                                                   │
│  ┌──────────────────────────────────────────────────────────┐  │
│  │ PARENT'S ATTENTION MONITORING (at each step)             │  │
│  │                                                            │  │
│  │ At Step t:                                                 │  │
│  │   Forward: Generate next reasoning step                   │  │
│  │   Backward: Check consistency with previous steps         │  │
│  │                                                            │  │
│  │   Anomaly Check:                                           │  │
│  │   • Does step t follow from step t-1?                     │  │
│  │   • Is the logic sound?                                    │  │
│  │   • Does activation pattern match expected?                │  │
│  │                                                            │  │
│  │   If anomaly detected → STOP, backtrack, restart          │  │
│  │   If looks good → Continue to next step                    │  │
│  └──────────────────────────────────────────────────────────┘  │
│                                                                   │
│  EXAMPLE WITH ERROR DETECTION:                                   │
│                                                                   │
│  Step 1: "First, add 15 + 23"                                    │
│          ↓                                                        │
│          Result: 38  ✓                                           │
│          ↓                                                        │
│  Step 2: "Now multiply 38 × 2"                                   │
│          ↓                                                        │
│          Result: 74  ⚠️ Parent's Attention: ANOMALY!            │
│          ↑                                                        │
│          │ Backward check: 38×2 should be 76, not 74            │
│          │ Activation deviation detected                         │
│          ↓                                                        │
│       BACKTRACK to Step 1, re-run Step 2                         │
│          ↓                                                        │
│  Step 2 (retry): "Multiply 38 × 2"                               │
│          ↓                                                        │
│          Result: 76  ✓ Parent's Attention: Fixed!               │
│          ↓                                                        │
│  Step 3: "Finally subtract 10"                                   │
│          ↓                                                        │
│          Result: 66  ✓                                           │
│                                                                   │
└──────────────────────────────────────────────────────────────────┘

┌──────────────────────────────────────────────────────────────────┐
│  OUTPUT WITH CORRECTION LOG                                       │
│  ┌──────────────────────────────────────────────────────────┐   │
│  │ Final Answer: 66                                          │   │
│  │                                                            │   │
│  │ Reasoning Chain:                                           │   │
│  │   Step 1: 15 + 23 = 38                                    │   │
│  │   Step 2: 38 × 2 = 76 (corrected from 74)                │   │
│  │   Step 3: 76 - 10 = 66                                    │   │
│  │                                                            │   │
│  │ Self-Corrections: 1                                        │   │
│  │ Confidence: 0.95                                           │   │
│  └──────────────────────────────────────────────────────────┘   │
└──────────────────────────────────────────────────────────────────┘

PyTorch Implementation:

python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ChainOfThoughtWithParentsAttention(nn.Module):
    """
    Chain of Thought reasoning with Parent's Attention monitoring.
    Linear sequential reasoning with backward error checking.
    """
    def __init__(self, vocab_size, d_model=512, n_heads=8, n_layers=6, threshold=0.5):
        super().__init__()
        self.d_model = d_model
        self.threshold = threshold
        
        # Token embedding
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = nn.Parameter(torch.randn(512, d_model))
        
        # Encoder
        self.encoder = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model, n_heads)
            for _ in range(n_layers)
        ])
        
        # Decoder with Parent's Attention
        self.decoder_layers = nn.ModuleList([
            ParentsAttentionLayer(d_model, n_heads, threshold)
            for _ in range(n_layers)
        ])
        
        # Output head
        self.output_head = nn.Linear(d_model, vocab_size)
        
    def forward(self, x, max_steps=10):
        """
        Generate chain of thought with backward monitoring.
        
        Args:
            x: Input tokens [batch, seq_len]
            max_steps: Maximum reasoning steps
        
        Returns:
            reasoning_chain: List of reasoning steps
            corrections: List of corrections made
        """
        batch_size = x.size(0)
        
        # Encode input
        h = self.embedding(x) + self.pos_encoding[:x.size(1)]
        for layer in self.encoder:
            h = layer(h)
        
        # Generate chain of thought with monitoring
        reasoning_chain = []
        corrections = []
        step_history = []
        
        current_step = 0
        while current_step < max_steps:
            # Generate next reasoning step
            step_result = self._generate_step(h, step_history)
            
            # Parent's Attention check
            if len(step_history) > 0:
                is_valid, anomaly_score = self._check_consistency(
                    step_result, step_history
                )
                
                if not is_valid:
                    # Error detected! Backtrack
                    corrections.append({
                        'step': current_step,
                        'anomaly_score': anomaly_score,
                        'previous_output': step_result['output'],
                        'action': 'backtrack_and_retry'
                    })
                    
                    # Retry this step (simplified - in practice, may adjust)
                    step_result = self._generate_step(h, step_history, retry=True)
            
            # Add to chain
            reasoning_chain.append(step_result)
            step_history.append(step_result['hidden_state'])
            
            # Check for termination
            if self._is_terminal(step_result):
                break
            
            current_step += 1
        
        return reasoning_chain, corrections
    
    def _generate_step(self, encoded_input, history, retry=False):
        """Generate next reasoning step."""
        if history:
            context = torch.stack(history)
        else:
            context = None
        
        current_state = encoded_input
        for layer in self.decoder_layers:
            current_state, should_stop, anomaly = layer(
                current_state, context
            )
            
            # In chain of thought, we don't stop mid-layer
            # Just accumulate anomaly signals
        
        # Generate output token
        logits = self.output_head(current_state[:, -1, :])
        output_token = torch.argmax(logits, dim=-1)
        
        return {
            'output': output_token,
            'hidden_state': current_state,
            'logits': logits
        }
    
    def _check_consistency(self, current_step, history):
        """
        Check if current step is consistent with history.
        Uses Parent's Attention mechanism.
        """
        current_hidden = current_step['hidden_state']
        history_states = torch.stack(history)
        
        # Compute consistency score
        current_mean = current_hidden.mean(dim=1)
        history_mean = history_states.mean(dim=(0, 1))
        
        consistency = F.cosine_similarity(
            current_mean, history_mean.expand_as(current_mean), dim=-1
        ).mean()
        
        anomaly_score = 1.0 - consistency
        is_valid = anomaly_score < self.threshold
        
        return is_valid, anomaly_score.item()
    
    def _is_terminal(self, step_result):
        """Check if this is a terminal step (e.g., final answer)."""
        # Simplified: check for EOS token
        return (step_result['output'] == 0).any()


class ParentsAttentionLayer(nn.Module):
    """Single layer with Parent's Attention monitoring."""
    def __init__(self, d_model, n_heads, threshold=0.5):
        super().__init__()
        self.attention = nn.MultiheadAttention(d_model, n_heads)
        self.backward_attention = nn.MultiheadAttention(d_model, n_heads)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.ReLU(),
            nn.Linear(d_model * 4, d_model)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.threshold = threshold
        
    def forward(self, x, history=None):
        """Forward with backward monitoring."""
        # Forward attention
        attn_out, _ = self.attention(x, x, x)
        x = self.norm1(x + attn_out)
        
        # Backward monitoring
        anomaly_score = 0.0
        should_stop = False
        
        if history is not None:
            backward_out, _ = self.backward_attention(x, history, history)
            # Check for anomalies
            deviation = torch.norm(backward_out - x, dim=-1).mean()
            anomaly_score = deviation
            should_stop = anomaly_score > self.threshold
        
        # FFN
        ffn_out = self.ffn(x)
        x = self.norm2(x + ffn_out)
        
        return x, should_stop, anomaly_score

Key Components:Linear reasoning: Single path, step-by-step (no branching) • Backward monitoring: Each step checked against previous steps • Error correction: Backtracks and retries when anomaly detected • Simpler than tree: No branch exploration, just sequential with checks • Efficient: Lower compute than tree-based methods

Why this architecture works:

  • Combines simplicity of chain-of-thought with error detection
  • Catches arithmetic/logical errors in real-time
  • More efficient than tree search when path is mostly correct
  • Maintains interpretability of linear reasoning

Use cases:

  • Multi-step calculations where errors compound
  • Sequential reasoning tasks (planning, proof generation)
  • Any chain-of-thought application where reliability matters
  • Resource-constrained scenarios (cheaper than tree exploration)

Performance characteristics:

  • Computational cost: 1.5-2× standard chain-of-thought (backward checks)
  • Memory: O(max_steps) for history
  • Accuracy: +10-15% over standard chain-of-thought
  • Correction rate: Catches ~70% of intermediate errors

Comparison with other variants:

FeatureStandard CoTCoT + Parent's AttentionTree of Thought + Parent's Attention
BranchingNoneNoneMultiple paths
Error DetectionNoneAt each stepAt each step per branch
CorrectionNoneBacktrack & retryPrune branch
Compute1.5-2×2-3×
Best forSimple tasksSequential with errorsComplex exploration

Example 5: BERT Model Architecture

BERT (Bidirectional Encoder Representations from Transformers) uses bidirectional attention for deep language understanding.

┌─────────────────────────────────────────────────────────────────┐
│                      BERT ARCHITECTURE                           │
└─────────────────────────────────────────────────────────────────┘

INPUT: "[CLS] The cat sat [MASK] the mat [SEP]"

┌──────────────────────────────────────────────────────────────────┐
│  EMBEDDING LAYER                                                  │
│  ┌──────────────┐  ┌──────────────┐  ┌──────────────┐          │
│  │   Token      │+│  Position    │+│  Segment     │           │
│  │  Embedding   │ │  Embedding   │ │  Embedding   │           │
│  │  [30K vocab] │ │  [512 pos]   │ │  [2 segments]│           │
│  └──────────────┘  └──────────────┘  └──────────────┘          │
│         ↓                 ↓                 ↓                     │
│         └─────────────────┴─────────────────┘                    │
│                           ↓                                       │
│              Combined Embedding [768-dim]                        │
└──────────────────────────────────────────────────────────────────┘

┌──────────────────────────────────────────────────────────────────┐
│  TRANSFORMER ENCODER LAYERS (×12)                                │
│                                                                   │
│  Layer 1:  ┌─────────────────────────────────────────┐          │
│            │  Multi-Head Self-Attention (12 heads)   │          │
│            │  ┌───┐ ┌───┐ ┌───┐       ┌───┐        │          │
│            │  │ H1│ │ H2│ │ H3│  ...  │H12│        │          │
│            │  └─┬─┘ └─┬─┘ └─┬─┘       └─┬─┘        │          │
│            │    └─────┴─────┴───────────┘           │          │
│            │    All heads attend to ALL tokens      │          │
│            │    (bidirectional - no masking)        │          │
│            └─────────────────────────────────────────┘          │
│                           ↓ [Add & Norm]                         │
│            ┌─────────────────────────────────────────┐          │
│            │  Feed-Forward Network                   │          │
│            │  Linear(768→3072) → GELU → Linear(3072→768)        │
│            └─────────────────────────────────────────┘          │
│                           ↓ [Add & Norm]                         │
│            ... (repeat 12 times) ...                             │
│                                                                   │
│  Each token can attend to EVERY other token (bidirectional)     │
│  Attention matrix is fully connected:                            │
│     [CLS]  The  cat  sat [MASK] the  mat [SEP]                  │
│  [CLS] [1    1    1    1    1    1    1    1  ]                │
│  The   [1    1    1    1    1    1    1    1  ]                │
│  cat   [1    1    1    1    1    1    1    1  ]                │
│  ...                                                             │
└──────────────────────────────────────────────────────────────────┘

┌──────────────────────────────────────────────────────────────────┐
│  OUTPUT HEADS                                                     │
│  ┌─────────────────────────────────────────────────────────┐    │
│  │  MLM Head (Masked Language Model)                       │    │
│  │  Predicts: [MASK] → "on" (probability: 0.87)           │    │
│  │  Linear(768 → 30K vocab) + Softmax                      │    │
│  └─────────────────────────────────────────────────────────┘    │
│  ┌─────────────────────────────────────────────────────────┐    │
│  │  NSP Head (Next Sentence Prediction)                    │    │
│  │  Uses [CLS] token: IsNext? Yes (0.92)                  │    │
│  │  Linear(768 → 2) + Softmax                              │    │
│  └─────────────────────────────────────────────────────────┘    │
└──────────────────────────────────────────────────────────────────┘

Dimensions:
• Hidden size: 768 (base) or 1024 (large)
• Attention heads: 12 (base) or 16 (large)  
• Layers: 12 (base) or 24 (large)
• FFN intermediate: 3072 (4× hidden size)
• Max sequence length: 512 tokens
• Total parameters: 110M (base) or 340M (large)

Why this architecture works:

  • Bidirectional attention captures full context (past and future)
  • Masked language modeling learns deep representations
  • Pre-training on massive text enables transfer learning
  • [CLS] token aggregates sequence-level information

Example 6: 4D Dilated Convolution with Dendrite Branches & Hypergraph Tips

Advanced convolutional architecture inspired by biological neurons with dendritic processing.

┌─────────────────────────────────────────────────────────────────┐
│     4D DILATED CONVOLUTION WITH DENDRITE HYPERGRAPHS            │
└─────────────────────────────────────────────────────────────────┘

INPUT: 4D Tensor [Batch, Channels, Depth, Height, Width]
       (e.g., video: [B, C, Time, H, W])

┌──────────────────────────────────────────────────────────────────┐
│  BASE 4D DILATED CONVOLUTION                                     │
│                                                                   │
│  Kernel: 3×3×3×3 with dilation=2 in each dimension              │
│  ┌────┐     ┌────┐     ┌────┐                                   │
│  │ ○  │ ... │ ○  │ ... │ ○  │  (27 positions in 4D space)      │
│  └────┘     └────┘     └────┘                                   │
│     Receptive field: 5×5×5×5 with 3×3×3×3 parameters           │
│                                                                   │
│  Standard convolution output → Feature maps [B, C', D', H', W']  │
└──────────────────────────────────────────────────────────────────┘

┌──────────────────────────────────────────────────────────────────┐
│  DENDRITE BRANCH PROCESSING                                      │
│  Each output feature has multiple dendrite branches:             │
│                                                                   │
│       Main Soma (cell body)                                      │
│            ↑                                                      │
│    ┌───────┼───────┐                                            │
│    │       │       │                                            │
│  Branch1 Branch2 Branch3 ... BranchN                           │
│    │       │       │         │                                  │
│    ↓       ↓       ↓         ↓                                  │
│  [3×3×3] [5×5×5] [7×7×7]  [Dilated]  ← Different receptive fields│
│  Conv4D  Conv4D  Conv4D   Conv4D                               │
│                                                                   │
│  Each branch learns different spatiotemporal patterns:           │
│  • Branch 1: Local fine details (3×3×3×3)                       │
│  • Branch 2: Medium patterns (5×5×5×5)                          │
│  • Branch 3: Large patterns (7×7×7×7)                           │
│  • Branch 4: Dilated long-range (dilation=4)                    │
│                                                                   │
│  Branch outputs combined: Σᵢ wᵢ·branchᵢ(x) + b                  │
└──────────────────────────────────────────────────────────────────┘

┌──────────────────────────────────────────────────────────────────┐
│  HYPERGRAPH TIPS (Multi-way Connections)                         │
│                                                                   │
│  Traditional: neuron connects to neurons (pairwise)              │
│  Hypergraph: neuron connects to SETS of neurons (higher-order)  │
│                                                                   │
│     ○────────○────────○   Standard edges (2-way)                │
│      ╲      ╱                                                    │
│       ╲    ╱                                                     │
│        ╲  ╱                                                      │
│         ○      Hyperedge (multi-way connection)                 │
│        ╱│╲                                                       │
│       ╱ │ ╲                                                      │
│      ╱  │  ╲                                                     │
│     ○   ○   ○   Connected as a SET                              │
│                                                                   │
│  Implementation:                                                 │
│  ┌─────────────────────────────────────────────────────┐       │
│  │ For each hyperedge h connecting set S = {n₁,n₂,...,nₖ}:     │
│  │                                                       │       │
│  │ aggregate = ϕ({f(nᵢ) : nᵢ ∈ S})  ← Set aggregation  │       │
│  │           = max/mean/sum pooling over set           │       │
│  │                                                       │       │
│  │ message = ψ(aggregate)  ← Transform                  │       │
│  │                                                       │       │
│  │ for each nᵢ in S:                                    │       │
│  │     nᵢ_new = nᵢ + α·message  ← Update               │       │
│  └─────────────────────────────────────────────────────┘       │
│                                                                   │
│  Captures higher-order relationships:                            │
│  • Triplet interactions: (object, action, context)              │
│  • Scene composition: (foreground, background, lighting)         │
│  • Temporal groups: (past, present, future) together            │
└──────────────────────────────────────────────────────────────────┘

OUTPUT: Enhanced 4D features with:
        • Multi-scale dendrite processing
        • Higher-order hypergraph relationships
        • Biological inspiration

Code Example:
```python
class DendriticHypergraph4DConv(nn.Module):
    def __init__(self, in_channels, out_channels, num_branches=4):
        super().__init__()
        # Dendrite branches with different scales
        self.branches = nn.ModuleList([
            nn.Conv3d(in_channels, out_channels, kernel_size=3, dilation=1),
            nn.Conv3d(in_channels, out_channels, kernel_size=3, dilation=2),
            nn.Conv3d(in_channels, out_channels, kernel_size=3, dilation=4),
            nn.Conv3d(in_channels, out_channels, kernel_size=3, dilation=8),
        ])
        # Hypergraph connections (attention over sets)
        self.hypergraph_attn = nn.MultiheadAttention(out_channels, num_heads=8)
        
    def forward(self, x):
        # Process through dendrite branches
        branch_outputs = [branch(x) for branch in self.branches]
        
        # Combine branches (soma integration)
        combined = sum(branch_outputs) / len(branch_outputs)
        
        # Hypergraph message passing
        # Treats spatial positions as nodes in hypergraph
        b, c, d, h, w = combined.shape
        flat = combined.view(b, c, -1).permute(2, 0, 1)  # [nodes, batch, channels]
        hypergraph_out, _ = self.hypergraph_attn(flat, flat, flat)
        output = hypergraph_out.permute(1, 2, 0).view(b, c, d, h, w)
        
        return output
```_

Why this architecture works:

  • Dendrite branches capture multi-scale spatiotemporal patterns
  • Hypergraph tips model higher-order relationships beyond pairwise
  • Biological inspiration: mimics how real neurons process information
  • 4D convolution for video/volumetric data with temporal dynamics

Example 7: 11×11×11 Stacked 3D SOM in 3×3×3 Cube

Self-Organizing Map (SOM) arranged in 3D hierarchical structure.

┌─────────────────────────────────────────────────────────────────┐
│        11×11×11 STACKED 3D SOM IN 3×3×3 CUBE HIERARCHY         │
└─────────────────────────────────────────────────────────────────┘

STRUCTURE: 3×3×3 = 27 large cubes, each containing 11×11×11 = 1331 neurons

┌─────────────────────────────────────────────────────────────────┐
│  OVERALL 3×3×3 CUBE ARRANGEMENT                                 │
│                                                                   │
│     Front Layer        Middle Layer       Back Layer             │
│    ┌───┬───┬───┐     ┌───┬───┬───┐     ┌───┬───┬───┐          │
│    │ 1 │ 2 │ 3 │     │10 │11 │12 │     │19 │20 │21 │          │
│    ├───┼───┼───┤     ├───┼───┼───┤     ├───┼───┼───┤          │
│    │ 4 │ 5 │ 6 │     │13 │14 │15 │     │22 │23 │24 │          │
│    ├───┼───┼───┤     ├───┼───┼───┤     ├───┼───┼───┤          │
│    │ 7 │ 8 │ 9 │     │16 │17 │18 │     │25 │26 │27 │          │
│    └───┴───┴───┘     └───┴───┴───┘     └───┴───┴───┘          │
│                                                                   │
│  Each numbered cube contains a complete 11×11×11 SOM            │
└──────────────────────────────────────────────────────────────────┘

┌──────────────────────────────────────────────────────────────────┐
│  INSIDE EACH 11×11×11 SOM CUBE (e.g., Cube #14 - center)        │
│                                                                   │
│  Visualizing as layers (11 slices of 11×11):                    │
│                                                                   │
│  Layer 0 (front):        Layer 5 (middle):      Layer 10 (back):│
│  ● ● ● ● ● ● ● ● ● ● ●  ● ● ● ● ● ● ● ● ● ● ●  ● ● ● ● ● ● ● ● ● ● ● │
│  ● ● ● ● ● ● ● ● ● ● ●  ● ● ● ● ● ● ● ● ● ● ●  ● ● ● ● ● ● ● ● ● ● ● │
│  ● ● ● ● ● ● ● ● ● ● ●  ● ● ● ● ● ☆ ● ● ● ● ●  ● ● ● ● ● ● ● ● ● ● ● │
│  ● ● ● ● ● ● ● ● ● ● ●  ● ● ● ● ● ● ● ● ● ● ●  ● ● ● ● ● ● ● ● ● ● ● │
│  ● ● ● ● ● ● ● ● ● ● ●  ● ● ● ☆ ☆ ☆ ☆ ● ● ● ●  ● ● ● ● ● ● ● ● ● ● ● │
│  ● ● ● ● ● ● ● ● ● ● ●  ● ● ● ● ● ● ● ● ● ● ●  ● ● ● ● ● ● ● ● ● ● ● │
│  ● ● ● ● ● ● ● ● ● ● ●  ● ● ● ● ● ● ● ● ● ● ●  ● ● ● ● ● ● ● ● ● ● ● │
│  ● ● ● ● ● ● ● ● ● ● ●  ● ● ● ● ● ● ● ● ● ● ●  ● ● ● ● ● ● ● ● ● ● ● │
│  ● ● ● ● ● ● ● ● ● ● ●  ● ● ● ● ● ● ● ● ● ● ●  ● ● ● ● ● ● ● ● ● ● ● │
│  ● ● ● ● ● ● ● ● ● ● ●  ● ● ● ● ● ● ● ● ● ● ●  ● ● ● ● ● ● ● ● ● ● ● │
│  ● ● ● ● ● ● ● ● ● ● ●  ● ● ● ● ● ● ● ● ● ● ●  ● ● ● ● ● ● ● ● ● ● ● │
│                                                                   │
│  ● = neuron/cell    ☆ = activated neurons (Best Matching Units) │
│                                                                   │
│  Total neurons per cube: 11 × 11 × 11 = 1,331 neurons          │
│  Total neurons in system: 27 × 1,331 = 35,937 neurons          │
└──────────────────────────────────────────────────────────────────┘

┌──────────────────────────────────────────────────────────────────┐
│  HOW IT WORKS (Self-Organizing Map Algorithm)                    │
│                                                                   │
│  1. INPUT: High-dimensional vector x (e.g., 1000-D)             │
│     ↓                                                             │
│  2. COMPETITION: Find Best Matching Unit (BMU)                   │
│     BMU = argminᵢ ||x - wᵢ||  (closest neuron to input)         │
│     ☆ marks the BMU                                             │
│     ↓                                                             │
│  3. COOPERATION: Define neighborhood around BMU                  │
│     h(i, BMU) = exp(-||rᵢ - r_BMU||² / 2σ²)                     │
│     Nearby neurons in 3D grid cooperate                          │
│     ↓                                                             │
│  4. ADAPTATION: Update weights of BMU and neighbors              │
│     wᵢ(t+1) = wᵢ(t) + α(t)·h(i, BMU)·(x - wᵢ(t))               │
│     α(t) = learning rate (decreases over time)                   │
│     σ(t) = neighborhood radius (shrinks over time)              │
│     ↓                                                             │
│  5. HIERARCHICAL COMMUNICATION:                                  │
│     • Within-cube: Neurons organize similar inputs locally       │
│     • Between-cubes: 3×3×3 structure organizes at higher level  │
│     • Center cube (14) often learns most common patterns         │
│     • Corner cubes (1, 3, 7, 9, 19, 21, 25, 27) learn extremes │
└──────────────────────────────────────────────────────────────────┘

┌──────────────────────────────────────────────────────────────────┐
│  INTER-CUBE CONNECTIONS (3×3×3 topology)                         │
│                                                                   │
│  Each cube connects to its neighbors in 3D grid:                 │
│  • Face neighbors (6): up, down, left, right, front, back       │
│  • Edge neighbors (12): diagonal edges                           │
│  • Corner neighbors (8): diagonal corners                        │
│                                                                   │
│  Example for center cube (14):                                   │
│  Face neighbors: 5, 11, 13, 15, 17, 23                          │
│  Edge neighbors: 2, 4, 6, 8, 10, 12, 16, 18, 20, 22, 24, 26     │
│  Corner neighbors: 1, 3, 7, 9, 19, 21, 25, 27                   │
│                                                                   │
│  Winner from one cube can propagate to neighboring cubes         │
└──────────────────────────────────────────────────────────────────┘

Code Example:
```python
import numpy as np

class Stacked3DSOM:
    def __init__(self, cube_size=11, meta_size=3, input_dim=1000):
        # 27 cubes of 11×11×11 SOMs
        self.cube_size = cube_size  # 11×11×11
        self.meta_size = meta_size  # 3×3×3
        self.input_dim = input_dim
        
        # Initialize weights for all neurons
        # Shape: [3, 3, 3, 11, 11, 11, input_dim]
        self.weights = np.random.randn(
            meta_size, meta_size, meta_size,  # Meta 3×3×3 structure
            cube_size, cube_size, cube_size,  # Each cube is 11×11×11
            input_dim  # Each neuron has input_dim weights
        )
        
    def find_bmu(self, x):
        """Find Best Matching Unit across all neurons"""
        # Compute distance from input to all neurons
        distances = np.sum((self.weights - x)**2, axis=-1)
        # Find coordinates of minimum distance neuron
        bmu_coords = np.unravel_index(np.argmin(distances), distances.shape)
        return bmu_coords
    
    def update(self, x, learning_rate=0.1, sigma=2.0):
        """Update SOM weights based on input"""
        # Find BMU
        bmu = self.find_bmu(x)
        
        # Create coordinate grids
        coords = np.mgrid[0:3, 0:3, 0:3, 0:11, 0:11, 0:11]
        
        # Compute distances from BMU in 6D lattice space
        distances = sum((coords[i] - bmu[i])**2 for i in range(6))
        
        # Neighborhood function
        h = np.exp(-distances / (2 * sigma**2))
        
        # Update weights
        h_expanded = h[..., np.newaxis]
        self.weights += learning_rate * h_expanded * (x - self.weights)
        
    def visualize_activation(self, x):
        """Show which cubes and neurons activate for input"""
        # Find BMU
        bmu = self.find_bmu(x)
        meta_cube = bmu[:3]  # Which of the 27 cubes
        within_cube = bmu[3:]  # Position within 11×11×11
        
        print(f"Activated meta-cube: {meta_cube} (position in 3×3×3)")
        print(f"BMU within cube: {within_cube} (position in 11×11×11)")
        return bmu
```*

Why this architecture works:

  • Hierarchical organization: 3×3×3 cubes for high-level, 11×11×11 for details
  • Topological preservation: Similar inputs mapped to nearby neurons
  • Unsupervised clustering: Discovers structure in data automatically
  • Dimensionality reduction: Projects high-D input to 3D visualization
  • Interpretable: Can visualize activation patterns in 3D space

Use cases:

  • Clustering high-dimensional data (genomics, sensor arrays)
  • Feature extraction and visualization
  • Anomaly detection (outliers activate corner cubes)
  • Hierarchical pattern recognition

Quick Glossary of Key Terms

Layer: A computational unit in a neural network that transforms input tensors to output tensors.

MLP (Multi-Layer Perceptron): Stack of linear layers with activation functions between them.

Head: In attention mechanisms, an independent attention computation. Multiple heads allow learning different representation subspaces.

Hyperparameter: Configuration value set before training (learning rate, batch size, etc.), not learned from data.

Parameter: Learnable weight or bias value optimized during training.

Token: Discrete unit of input (word, subword, character, or image patch) mapped to a vector.

Vector: Ordered list of numbers representing data in continuous space (e.g., [1.2, -0.5, 3.4]).

Tensor: Multi-dimensional array generalizing vectors and matrices (scalars are 0D, vectors 1D, matrices 2D, etc.).

Embedding: Mapping from discrete tokens to continuous vectors in learned representation space.

Kernel: Small matrix of learnable weights that slides over input during convolution operations.

Sparse: Having many zero values (opposite of dense). Reduces computation and memory.

Dense: Having mostly non-zero values. Synonym for fully connected.

Sparsity: Fraction of zero values in a tensor or network. High sparsity = more zeros.

Hidden State: Internal representation computed by a layer or recurrent network.

Feature: Measurable property or characteristic extracted by network layers.

Neuron: Single computational unit that applies weighted sum plus activation function.

Activation Function: Non-linear function applied element-wise to neuron outputs (e.g., ReLU, GELU).

Feedforward: Forward flow of information from input to output without cycles or loops.

Backpropagation: Algorithm for computing gradients by propagating errors backward through network.

Gradient: Vector of partial derivatives indicating how to update parameters to reduce loss.

Loss Function: Measures how far predictions are from targets. Optimization minimizes this.

Batch: Set of samples processed together in parallel during training or inference.

Epoch: One complete pass through entire training dataset.

Learning Rate: Hyperparameter controlling step size during parameter updates.

Overfitting: Model learns training data too well, failing to generalize to new data.

Regularization: Techniques to prevent overfitting (dropout, weight decay, etc.).

Dropout: Randomly zeros some neurons during training for regularization.

Attention: Mechanism for selectively focusing on relevant parts of input sequence.

Encoder: Network component that compresses input into latent representation.

Decoder: Network component that generates output from latent representation.

Autoregressive: Generates output one token at a time, conditioning on previous tokens.

Sequence: Ordered list of tokens or vectors (e.g., sentence, time series).

Context Window: Maximum sequence length a model can process at once.

Fine-tuning: Adapting pre-trained model to specific task with additional training.

Transfer Learning: Using knowledge from one task to improve performance on another.

Inference: Using trained model to make predictions on new data (not training).

Backpropagation Through Time (BPTT): Backpropagation variant for recurrent networks.


Table of Contents

  1. Quick Glossary
  2. Layer Types & Connectivity
  3. Convolutional Variants
  4. Recurrent & Temporal Architectures
  5. Attention Mechanisms
  6. Normalization Techniques
  7. Activation Functions
  8. Pooling & Downsampling
  9. Optimization Algorithms
  10. Loss Functions
  11. Regularization Techniques
  12. Graph Neural Networks
  13. Neuromorphic & Spiking Networks
  14. Continuous-Time & Liquid Networks
  15. Generative Models
  16. Meta-Learning & Hypernetworks
  17. Quantum & Optical Networks
  18. Specialized Architectures
  19. Hardware-Specific Optimizations
  20. Block Types & Architectural Patterns
  21. Embedding Strategies
  22. Data Flow Directions
  23. Theory of Mind & Cognitive Architectures
  24. Persona Manifolds & Identity Embeddings
  25. Guardrails & Safety Mechanisms

Layer Types & Connectivity

Linear (Dense / Fully Connected)

What it does: Matrix multiplication + bias. Maps input to output via learned weights. Output = xW^T + b where W is the weight matrix and b is the bias vector.

The math:

For input x ∈ ℝ^n, weights W ∈ ℝ^(m×n), bias b ∈ ℝ^m:
y = xW^T + b
Each neuron: y_i = Σ(w_ij * x_j) + b_i
python
# PyTorch implementation
nn.Linear(in_features=512, out_features=256, bias=True)

When to use: Final classification layers, MLPs (multi-layer perceptrons), feature transformation.

Use cases:

Hyperparameters:

Pairs with: Any optimizer; use weight decay for regularization.

Models using it: Every neural network (ResNet, BERT, GPT, ViT, etc.)

Related terms: Dense layer, fully connected layer, affine transformation

Papers: "Learning representations by back-propagating errors" (Rumelhart et al., 1986)

Sparse Linear

What it does: Linear layer with enforced sparsity (most weights = 0). Reduces memory and computation by keeping only important connections between neurons.

The math:

W_sparse = W ⊙ M  where M is binary mask
y = x(W ⊙ M)^T + b
Sparsity ratio = (# zeros) / (total elements)
python
# Custom implementation or use pruning
mask = torch.rand(512, 256) > 0.9  # 90% sparse (90% zeros)
sparse_weight = weight * mask
output = F.linear(input, sparse_weight, bias)

When to use: Memory-constrained environments, edge devices, lottery ticket hypothesis experiments, network pruning.

Use cases:

  • Mobile model deployment (reduce model size for phones/IoT)
  • Network pruning research (finding minimal winning subnetworks)
  • Finding optimal sparse subnetworks (lottery ticket hypothesis)
  • Efficient inference on resource-limited hardware

Hyperparameters:

  • sparsity_ratio: Fraction of weights set to zero (e.g., 0.9 = 90% sparse)
  • pruning_method: Magnitude-based, gradient-based, structured, or unstructured pruning
  • pruning_schedule: When and how much to prune during training

Pairs with: Magnitude pruning, L1 regularization (encourages sparsity).

Models using it: Pruned versions of ResNet, BERT; lottery ticket experiments.

Related terms: Pruning, structured sparsity, magnitude pruning

Papers: "The Lottery Ticket Hypothesis" (Frankle & Carbin, 2019)

Bilinear

What it does: Learns interactions between two input vectors. Computes pairwise multiplicative interactions.

The math:

For inputs x₁ ∈ ℝ^m, x₂ ∈ ℝ^n:
y = x₁^T W x₂ + b  where W ∈ ℝ^(m×n×k)
Or equivalently: y_k = Σᵢⱼ x₁ᵢ W_ijk x₂ⱼ + b_k
python
nn.Bilinear(in1_features=128, in2_features=128, out_features=64)
# Computes: output = x1^T @ weight @ x2 + bias

When to use: Modeling pairwise interactions, visual question answering, multimodal fusion, relation modeling.

Use cases:

  • Visual question answering (fusing image and text features)
  • Recommendation systems (user-item interactions)
  • Relation extraction (entity pair modeling)
  • Multimodal learning (combining different input types)

Hyperparameters:

Models using it: Visual question answering models, multimodal transformers.

Related terms: Pairwise interaction, multiplicative interaction, tensor product

Embedding

What it does: Maps discrete tokens (integers) to continuous vectors. Learnable lookup table for categorical data. Each token ID maps to a unique dense vector representation.

The math:

For token ID i ∈ {0, 1, ..., vocab_size-1}:
embedding[i] → v ∈ ℝ^d (d-dimensional vector)
Implemented as: E ∈ ℝ^(vocab_size × embedding_dim)
Lookup: v_i = E[i]  (index into embedding matrix)
python
embedding = nn.Embedding(num_embeddings=50000, embedding_dim=512, padding_idx=0)
token_ids = torch.tensor([1, 42, 100, 5])  # Batch of token IDs
vectors = embedding(token_ids)  # Output: [4, 512] tensor of vectors

When to use: NLP (word/subword tokens), categorical features, user/item IDs in recommender systems, entity embeddings.

Use cases:

Hyperparameters:

  • num_embeddings: Vocabulary size (50K typical for NLP, varies for categorical)
  • embedding_dim: Vector dimension (128-1024, often 512 or 768)
  • padding_idx: Index to keep at zero (padding tokens)
  • max_norm: If set, renormalizes embeddings exceeding this norm
  • norm_type: Type of norm for renormalization (2.0 for L2 norm)

Pairs with: AdamW, embedding dropout, sparse optimizers for rare tokens.

Models using it: Every language model (BERT, GPT, T5), recommender systems (NCF), graph neural networks (node embeddings).

Related terms: Word embeddings, token embeddings, entity embeddings, lookup table

Papers: "Efficient Estimation of Word Representations in Vector Space" (Mikolov et al., 2013 - Word2Vec)

Mixture of Experts (MoE)

What it does: Routes inputs to specialized sub-networks (experts) via learned gating. Only activates subset of experts per input (sparse activation). Enables massive scaling with controlled compute.

The math:

G(x) = Softmax(TopK(x·W_g))  # Gating network selects experts
y = Σᵢ G(x)ᵢ · Expertᵢ(x)     # Weighted combination of expert outputs

For n experts, top_k selected:
- Standard: All n experts computed (dense)
- Sparse MoE: Only top_k experts computed
- Compute per token: O(top_k) instead of O(n)
python
class MixtureOfExperts(nn.Module):
    def __init__(self, n_experts=8, d_model=512, top_k=2):
        super().__init__()
        # Each expert is a feedforward network
        self.experts = nn.ModuleList([FFN(d_model) for _ in range(n_experts)])
        # Gating network decides which experts to use
        self.gate = nn.Linear(d_model, n_experts)
        self.top_k = top_k
    
    def forward(self, x):
        # Compute gating scores for all experts
        gate_logits = self.gate(x)  # [batch, n_experts]
        
        # Select top-k experts per token (sparse activation)
        top_k_logits, top_k_indices = torch.topk(gate_logits, self.top_k, dim=-1)
        top_k_weights = F.softmax(top_k_logits, dim=-1)
        
        # Route to selected experts and combine outputs
        output = torch.zeros_like(x)
        for i in range(self.top_k):
            expert_idx = top_k_indices[:, i]
            expert_weight = top_k_weights[:, i:i+1]
            expert_out = torch.stack([self.experts[idx](x[j]) for j, idx in enumerate(expert_idx)])
            output += expert_weight * expert_out
        return output

When to use: Scaling models efficiently to billions/trillions of parameters while keeping computation per-token constant, specialized processing.

Use cases:

  • Massive language models (Switch Transformer 1.6T params)
  • Efficient scaling (more parameters without proportional compute increase)
  • Specialized processing (different experts for different input types)
  • Domain-specific models (experts for code, math, reasoning, etc.)

Hyperparameters:

  • n_experts: Number of expert networks (8-256 typical, up to 2048 in Switch-XXL)
  • top_k: Experts activated per token (1-2 typical, balances quality/compute)
  • capacity_factor: Expert capacity limit (1.0-1.5, prevents overload)
  • load_balancing_loss_weight: Encourages balanced expert usage (0.01 typical)
  • expert_dropout: Dropout rate within experts

Pairs with: Load balancing losses (auxiliary loss to prevent expert collapse), expert-specific regularization.

Models using it: Switch Transformer (1.6T params), GShard, GLaM, Mixtral 8x7B.

Related terms: Sparse gating, expert routing, conditional computation, expert choice routing

Papers: "Outrageously Large Neural Networks: The Sparsely-Gated MoE Layer" (Shazeer et al., 2017)

Capsule Networks

What it does: Groups neurons into capsules representing part-whole hierarchies.

python
# Dynamic routing between capsules
# Preserves spatial relationships

When to use: Object recognition with viewpoint invariance. Papers: "Dynamic Routing Between Capsules" (Sabour et al., 2017).

Low-Rank Adaptation (LoRA)

What it does: Fine-tunes large models by training low-rank decomposition matrices instead of full weights. The math: For a weight matrix W, learn ΔW = BA where B ∈ ℝ^(d×r), A ∈ ℝ^(r×k), and r << min(d,k)

python
# Using PEFT library
from peft import LoraConfig, get_peft_model

config = LoraConfig(
    r=8,  # Low-rank dimension
    lora_alpha=32,  # Scaling factor
    target_modules=["q_proj", "v_proj"],  # Which layers to adapt
    lora_dropout=0.1
)
model = get_peft_model(base_model, config)

When to use: Fine-tuning LLMs (GPT, LLaMA) with limited compute, parameter-efficient transfer learning. Models using it: Most modern LLM fine-tuning (Alpaca, Vicuna), Stable Diffusion LoRA. Pairs with: AdamW, lower learning rates than full fine-tuning (1e-4 to 1e-3). Papers: "LoRA: Low-Rank Adaptation of Large Language Models" (Hu et al., 2021).

Adapter Layers

What it does: Inserts small trainable modules between frozen pre-trained layers.

python
class Adapter(nn.Module):
    def __init__(self, hidden_size, adapter_size=64):
        super().__init__()
        self.down_proj = nn.Linear(hidden_size, adapter_size)
        self.up_proj = nn.Linear(adapter_size, hidden_size)
        self.activation = nn.ReLU()
    
    def forward(self, x):
        return x + self.up_proj(self.activation(self.down_proj(x)))

When to use: Multi-task learning, domain adaptation without retraining full model. Models using it: BERT adapters, T5 adapters, Vision Transformer adapters. Pairs with: Higher learning rate for adapters than base model.

Prefix Tuning

What it does: Prepends learnable continuous vectors (virtual tokens) to input sequences.

python
class PrefixTuning(nn.Module):
    def __init__(self, prefix_length, hidden_dim):
        super().__init__()
        self.prefix = nn.Parameter(torch.randn(prefix_length, hidden_dim))
    
    def forward(self, x):
        batch_size = x.size(0)
        prefix_expanded = self.prefix.unsqueeze(0).expand(batch_size, -1, -1)
        return torch.cat([prefix_expanded, x], dim=1)

When to use: Prompt engineering, controllable generation, task-specific fine-tuning. Models using it: GPT-style models, T5 for controllable generation. Papers: "Prefix-Tuning: Optimizing Continuous Prompts for Generation" (Li & Liang, 2021).

Prompt Tuning

What it does: Similar to prefix tuning but only learns the embedding layer prompts.

python
# Learnable prompt tokens added to vocabulary
prompt_embeddings = nn.Embedding(num_virtual_tokens, embedding_dim)

When to use: Parameter-efficient fine-tuning with minimal trainable params. Models using it: T5, GPT variants for few-shot learning.

Gated Linear Units (GLU)

What it does: Element-wise product of two linear transformations, one gated by sigmoid/softmax. The math: GLU(x) = (xW + b) ⊗ σ(xV + c)

python
class GLU(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)
        self.gate = nn.Linear(input_dim, output_dim)
    
    def forward(self, x):
        return self.linear(x) * torch.sigmoid(self.gate(x))

When to use: Language modeling, replacing FFN in Transformers. Models using it: GPT-2/3, PaLM (uses SwiGLU variant). Variants: ReGLU (ReLU gate), SwiGLU (Swish gate), GeGLU (GELU gate).

Expert Choice Routing

What it does: Experts select top-k tokens instead of tokens selecting experts (inverse MoE).

python
# Conceptual: Each expert picks its favorite tokens
expert_scores = compute_affinity(experts, tokens)  # Shape: [num_experts, num_tokens]
top_k_tokens = expert_scores.topk(k, dim=1)

When to use: Better load balancing than standard MoE, reducing token dropping. Models using it: Google's Switch Transformer variants.


Convolutional Variants

1D Convolution

What it does: Slides kernel over 1D sequences (time-series, audio, text). Extracts local patterns along single axis.

The math:

For input x ∈ ℝ^(L×C_in), kernel W ∈ ℝ^(k×C_in×C_out):
y[i,c] = Σⱼ Σₐ x[i+j, a] · W[j, a, c] + b[c]
where i is position, j is kernel offset, a is input channel
python
nn.Conv1d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1)
# Input: [batch, 64, sequence_length]
# Output: [batch, 128, sequence_length]

When to use: Audio waveforms, DNA sequences, temporal signals, text convolution.

Use cases:

  • Audio processing (WaveNet, speech recognition)
  • Time-series analysis (sensor data, stock prices)
  • DNA/protein sequence analysis
  • Text classification (character/word-level CNNs)

Hyperparameters:

  • kernel_size: Size of convolving kernel (3, 5, 7 typical)
  • stride: Step size for kernel movement (1 = dense, 2 = downsample)
  • padding: Zero-padding added to input (maintains length if padding = kernel_size // 2)
  • dilation: Spacing between kernel elements (1 = standard, >1 = dilated)

Models using it: WaveNet, ByteNet, TCN (Temporal Convolutional Networks)

Related terms: Temporal convolution, causal convolution

2D Convolution

What it does: Standard spatial convolution for images. Kernel slides over height and width dimensions, extracting visual features.

The math:

For input x ∈ ℝ^(H×W×C_in), kernel W ∈ ℝ^(k_h×k_w×C_in×C_out):
y[i,j,c] = ΣₘΣₙΣₐ x[i+m, j+n, a] · W[m, n, a, c] + b[c]
Receptive field: k_h × k_w spatial region
python
nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1)
# Input: [batch, 3, height, width] (RGB image)
# Output: [batch, 64, height, width] (feature maps)

When to use: Image classification, object detection, segmentation, any 2D visual task.

Use cases:

Hyperparameters:

  • kernel_size: Kernel spatial size (3×3 most common, 5×5, 7×7 for larger receptive fields)
  • stride: Downsampling factor (1 = same size, 2 = half size)
  • padding: Border padding (1 for 3×3 kernel maintains size)
  • dilation: Dilated convolution spacing
  • groups: Number of groups for grouped convolution

Models using it: ResNet, VGG, Inception, U-Net, virtually all CNNs

Related terms: Spatial convolution, feature map, receptive field

3D Convolution

What it does: Operates on volumetric data (video, medical scans). Kernel slides over height, width, AND depth/time dimensions.

The math:

For input x ∈ ℝ^(D×H×W×C_in), kernel W ∈ ℝ^(k_d×k_h×k_w×C_in×C_out):
y[i,j,k,c] = ΣₗΣₘΣₙΣₐ x[i+l, j+m, k+n, a] · W[l, m, n, a, c] + b[c]
Captures spatiotemporal patterns
python
nn.Conv3d(in_channels=1, out_channels=32, kernel_size=3, stride=1, padding=1)
# Input: [batch, 1, depth, height, width] or [batch, channels, time, height, width]
# Output: [batch, 32, depth, height, width]

When to use: Video action recognition, CT/MRI analysis, volumetric data processing, spatiotemporal modeling.

Use cases:

  • Video action recognition (3D CNN over time dimension)
  • Medical image analysis (CT/MRI volume segmentation)
  • 3D object recognition (point clouds, voxel grids)
  • Temporal action detection

Hyperparameters:

  • kernel_size: Can be tuple (depth, height, width) or single int
  • stride: Downsampling in each dimension
  • padding: 3D padding

Models using it: C3D, I3D, 3D ResNet, 3D U-Net

Related terms: Spatiotemporal convolution, volumetric convolution

Papers: "Learning Spatiotemporal Features with 3D CNNs" (Tran et al., 2015 - C3D)

Dilated (Atrous) Convolution

What it does: Expands receptive field without increasing parameters via spacing. Kernel elements have gaps (dilation).

The math:

Standard convolution: y[i] = Σⱼ x[i+j] · w[j]
Dilated convolution: y[i] = Σⱼ x[i + j·r] · w[j]  where r is dilation rate
Receptive field: k + (k-1)(r-1) instead of k
python
nn.Conv2d(64, 128, kernel_size=3, dilation=2, padding=2)
# dilation=1: standard 3×3 (receptive field = 3)
# dilation=2: 5×5 receptive field with 3×3 parameters
# dilation=4: 9×9 receptive field with 3×3 parameters

When to use: Dense prediction (segmentation), audio generation (WaveNet), large receptive fields without pooling.

Use cases:

Hyperparameters:

  • dilation: Spacing between kernel elements (1 = standard, 2, 4, 8... for exponential growth)
  • padding: Adjust to maintain output size (padding = dilation × (kernel_size - 1) / 2)

Models using it: DeepLab, WaveNet, PixelCNN

Related terms: Atrous convolution, dilated receptive field

Papers: "Multi-Scale Context Aggregation by Dilated Convolutions" (Yu & Koltun, 2016)

Transposed Convolution (Deconvolution)

What it does: Upsamples spatial dimensions (learnable upsampling). Reverse of convolution in terms of shape transformation.

The math:

Standard conv: x[H,W] → y[H/s, W/s]  (downsample by stride s)
Transposed conv: x[H,W] → y[H×s, W×s]  (upsample by stride s)
Learnable upsampling (not true mathematical inverse)
python
nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1)
# Input: [batch, 128, H, W]
# Output: [batch, 64, 2H, 2W] (doubles spatial dimensions)

When to use: GANs, autoencoders, semantic segmentation, upsampling decoder.

Use cases:

Hyperparameters:

  • kernel_size: Size of kernel (4×4 common for 2× upsampling)
  • stride: Upsampling factor (2 = double size)
  • padding, output_padding: Control exact output dimensions

Models using it: DCGAN, U-Net, FCN (fully convolutional networks)

Related terms: Deconvolution, upsampling, fractional-strided convolution

Note: Name "deconvolution" is misnomer - it's not true mathematical deconvolution, just learnable upsampling.

Depthwise Separable Convolution

What it does: Splits into depthwise (per-channel) + pointwise (1×1) for efficiency.

python
nn.Conv2d(64, 64, 3, groups=64)  # Depthwise
nn.Conv2d(64, 128, 1)             # Pointwise

When to use: Mobile/edge models (MobileNet, EfficientNet). Papers: "MobileNets" (Howard et al., 2017).

Grouped Convolution

What it does: Splits channels into groups, convolves independently.

python
nn.Conv2d(64, 128, kernel_size=3, groups=4)

When to use: Reducing parameters (ResNeXt, ShuffleNet).

Deformable Convolution

What it does: Learns spatial offsets for adaptive receptive fields.

python
# From torchvision.ops or custom
DeformConv2d(64, 128, kernel_size=3)

When to use: Object detection with geometric variation. Papers: "Deformable Convolutional Networks" (Dai et al., 2017).

Octave Convolution

What it does: Processes high/low spatial frequencies separately.

python
# Custom implementation

When to use: Efficient multi-scale feature extraction. Papers: "Drop an Octave" (Chen et al., 2019).


Recurrent & Temporal Architectures

Vanilla RNN

What it does: Processes sequences with hidden state feedback.

python
nn.RNN(input_size=128, hidden_size=256, num_layers=2)

When to use: Simple sequence modeling (rarely used now). Issues: Vanishing gradients.

LSTM (Long Short-Term Memory)

What it does: Gated RNN with cell state for long-term dependencies.

python
nn.LSTM(input_size=128, hidden_size=256, num_layers=2)

When to use: Time-series, NLP (pre-Transformer era). Papers: "Long Short-Term Memory" (Hochreiter & Schmidhuber, 1997).

GRU (Gated Recurrent Unit)

What it does: Simplified LSTM with fewer gates.

python
nn.GRU(input_size=128, hidden_size=256)

When to use: Faster alternative to LSTM.

Bidirectional RNN/LSTM/GRU

What it does: Processes sequence forward and backward.

python
nn.LSTM(128, 256, bidirectional=True)

When to use: When future context is available (NLP tagging).

Temporal Convolutional Network (TCN)

What it does: 1D dilated convolutions for sequence modeling.

python
# Stack of dilated Conv1d layers

When to use: Alternative to RNNs; parallelizable. Papers: "An Empirical Evaluation of Generic Convolutional and Recurrent Networks" (Bai et al., 2018).

State Space Models (S4, Mamba)

What it does: Linear state-space models for efficient long-range dependencies.

python
# Mamba/S4 implementation

When to use: Ultra-long sequences (genomics, audio). Papers: "Efficiently Modeling Long Sequences with Structured State Spaces" (Gu et al., 2022).


Attention Mechanisms

Self-Attention

What it does: Computes attention over the same sequence. Each position attends to all positions in the same sequence. The math:

Attention(Q,K,V) = softmax(QK^T / √d_k)V
where Q = K = V = xW (same input transformed)
python
Q = K = V = x  # Same sequence
attn_scores = (Q @ K.T) / math.sqrt(d_k)
attn_weights = F.softmax(attn_scores, dim=-1)
output = attn_weights @ V

When to use: Transformers, sequence modeling where context from all positions is needed. Use cases: BERT encoding, GPT generation, ViT image patches, protein sequence modeling. Hyperparameters:

  • d_k: Key/query dimension (typically 64 per head)
  • dropout: Attention dropout rate (typically 0.1) Models using it: BERT, GPT series, T5, ViT, CLIP.

Cross-Attention (Encoder-Decoder Attention)

What it does: Attends from one sequence (decoder) to another sequence (encoder). Queries come from decoder, keys and values from encoder. The math:

Attention(Q,K,V) = softmax(QK^T / √d_k)V
where Q = decoder_hidden, K = V = encoder_output
python
Q = decoder_hidden  # From decoder
K = V = encoder_output  # From encoder
attn_scores = (Q @ K.T) / math.sqrt(d_k)
attn_weights = F.softmax(attn_scores, dim=-1)
output = attn_weights @ V

When to use: Encoder-decoder architectures, translation, image captioning where decoder needs to attend to encoder context. Use cases: Machine translation (English→French), image captioning (image→text), speech recognition (audio→text), document summarization. Hyperparameters:

  • d_k: Key/query dimension
  • num_heads: Number of attention heads (8-16 typical)
  • dropout: Attention dropout Models using it: Original Transformer, BART, T5, Whisper (speech), BLIP (vision-language).

Multi-Head Attention (MHA)

What it does: Runs multiple self-attention operations in parallel, each learning different representation subspaces. Concatenates outputs. The math:

MultiHead(Q,K,V) = Concat(head_1,...,head_h)W^O
where head_i = Attention(QW_i^Q, KW_i^K, VW_i^V)
python
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model=512, num_heads=8, dropout=0.1):
        super().__init__()
        self.d_k = d_model // num_heads
        self.num_heads = num_heads
        self.qkv_proj = nn.Linear(d_model, 3 * d_model)
        self.o_proj = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

When to use: All modern Transformers for capturing multiple types of relationships. Use cases: Any transformer model, enables learning of different attention patterns (syntactic, semantic, positional). Hyperparameters:

  • num_heads: Number of attention heads (8 for base, 16-32 for large models)
  • d_model: Total model dimension (must be divisible by num_heads)
  • dropout: Attention dropout (0.1 typical) Models using it: BERT (12 heads), GPT-3 (96 heads), ViT (12-16 heads), LLaMA (32-64 heads). Papers: "Attention Is All You Need" (Vaswani et al., 2017).

Bidirectional Attention

What it does: Attention where each token can attend to both past and future tokens (no causal masking). Used in encoders. The math: Same as self-attention but without causal mask.

python
# No masking - all positions can attend to all positions
attn_scores = (Q @ K.T) / math.sqrt(d_k)
attn_weights = F.softmax(attn_scores, dim=-1)  # Full attention matrix
output = attn_weights @ V

When to use: Encoding tasks where full context is available (classification, understanding, not generation). Use cases: BERT-style pre-training, text classification, named entity recognition, sentiment analysis, encoder in translation. Hyperparameters:

  • Same as Multi-Head Attention Models using it: BERT, RoBERTa, ALBERT, DeBERTa, encoder-only models.

Causal (Autoregressive) Attention

What it does: Attention with causal masking where token i can only attend to positions ≤ i. Prevents looking ahead during generation. The math:

mask[i,j] = -∞ if j > i else 0
attn = softmax((QK^T / √d_k) + mask)V
python
# Create causal mask
seq_len = Q.size(1)
mask = torch.triu(torch.ones(seq_len, seq_len) * float('-inf'), diagonal=1)
attn_scores = (Q @ K.T) / math.sqrt(d_k) + mask
attn_weights = F.softmax(attn_scores, dim=-1)
output = attn_weights @ V

When to use: Autoregressive generation tasks (language models, decoders). Use cases: GPT-style text generation, code completion, dialogue, any left-to-right generation. Hyperparameters:

  • Same as Multi-Head Attention
  • Causal mask is binary, not a hyperparameter Models using it: GPT series, decoder in Transformer, LLaMA, PaLM, Falcon.

Rotary Position Embedding (RoPE) Attention

What it does: Applies rotation to query/key based on position. Encodes relative position directly in attention computation via rotation matrices. The math:

RoPE(x, m) = [x_0, x_1, ..., x_{d-1}] rotated by angles θ_i * m
where θ_i = 10000^(-2i/d) and m is position
python
def apply_rotary_emb(x, positions):
    # x: [batch, seq_len, d_model]
    d = x.size(-1)
    theta = 1.0 / (10000 ** (torch.arange(0, d, 2).float() / d))
    pos_theta = positions.unsqueeze(-1) * theta
    
    # Apply rotation
    cos = torch.cos(pos_theta)
    sin = torch.sin(pos_theta)
    # Rotate even/odd dimensions
    x_rot = torch.stack([-x[..., 1::2], x[..., ::2]], dim=-1).flatten(-2)
    return x * cos + x_rot * sin

When to use: When you need better length extrapolation than learned positional embeddings. Use cases: Long context models, models trained on short sequences but deployed on longer ones. Hyperparameters:

  • theta_base: Base for angle computation (default: 10000)
  • max_position: Maximum position for precomputation Models using it: GPT-NeoX, LLaMA, PaLM, Falcon, CodeGen. Papers: "RoFormer: Enhanced Transformer with Rotary Position Embedding" (Su et al., 2021).

Grouped-Query Attention (GQA)

What it does: Shares keys/values across query groups for efficiency. Middle ground between Multi-Head and Multi-Query attention. The math:

num_kv_heads < num_query_heads
Each KV head is shared by multiple query heads
Memory: num_kv_heads * d_k instead of num_query_heads * d_k
python
class GroupedQueryAttention(nn.Module):
    def __init__(self, d_model=512, num_q_heads=32, num_kv_heads=8):
        super().__init__()
        self.num_q_heads = num_q_heads
        self.num_kv_heads = num_kv_heads
        self.group_size = num_q_heads // num_kv_heads
        # Fewer KV projections
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model * num_kv_heads // num_q_heads)
        self.v_proj = nn.Linear(d_model, d_model * num_kv_heads // num_q_heads)

When to use: Large language models where KV cache memory is bottleneck during inference. Use cases: Inference optimization for LLMs, reducing memory without losing much quality. Hyperparameters:

  • num_query_heads: Number of query heads (e.g., 32)
  • num_kv_heads: Number of KV heads (e.g., 8, divides num_query_heads)
  • Typical ratios: 4:1 or 8:1 (queries:kv) Models using it: LLaMA 2 (8 KV heads, 32 query heads), Mistral 7B. Papers: "GQA: Training Generalized Multi-Query Transformer" (Ainslie et al., 2023).

Multi-Query Attention (MQA)

What it does: Single key/value head shared across all query heads. Extreme version of GQA. The math:

num_kv_heads = 1
All query heads share same K, V
Saves (num_heads - 1) * d_k memory
python
class MultiQueryAttention(nn.Module):
    def __init__(self, d_model=512, num_heads=8):
        super().__init__()
        self.num_heads = num_heads
        self.q_proj = nn.Linear(d_model, d_model)
        # Single KV head
        self.k_proj = nn.Linear(d_model, d_model // num_heads)
        self.v_proj = nn.Linear(d_model, d_model // num_heads)

When to use: Extreme inference speed optimization, when KV cache memory is critical bottleneck. Use cases: Fast inference on edge devices, serving many users with limited memory. Hyperparameters:

  • num_query_heads: Number of query heads (8-32)
  • Always 1 KV head Models using it: PaLM, Falcon, StarCoder, early GPT-J variants. Papers: "Fast Transformer Decoding" (Shazeer, 2019).

FlashAttention

What it does: IO-aware exact attention algorithm that reduces memory reads/writes. Computes attention in blocks, fusing operations. The math: Same as standard attention but with block-wise computation and online softmax.

python
# Conceptual - actual implementation is in CUDA
# from flash_attn import flash_attn_func
# output = flash_attn_func(q, k, v, dropout_p=0.1, causal=True)

# Block-wise computation reduces HBM reads from O(N²) to O(N)
# Fuses softmax + matmul to reduce memory traffic

When to use: Training large transformers efficiently, when memory bandwidth is bottleneck. Use cases: Training GPT-3 scale models, long context models, reducing training time by 2-4×. Hyperparameters:

  • block_size: Tile size for block computation (tuned to SRAM size)
  • causal: Whether to apply causal masking
  • dropout_p: Attention dropout probability Models using it: GPT-4, LLaMA training, any large transformer training. Performance: 2-4× faster training, enables 64K+ context lengths. Papers: "FlashAttention: Fast and Memory-Efficient Exact Attention" (Dao et al., 2022).

Sparse Attention

What it does: Attends to subset of tokens using fixed patterns (not all O(n²) pairs). Reduces complexity for long sequences. The math:

Attention computed only for positions in sparse pattern
Complexity: O(n√n) or O(n log n) instead of O(n²)
python
# Longformer pattern: local window + global tokens
def create_sparse_mask(seq_len, window_size=256, global_tokens=[0]):
    mask = torch.zeros(seq_len, seq_len)
    # Local window
    for i in range(seq_len):
        start = max(0, i - window_size)
        end = min(seq_len, i + window_size + 1)
        mask[i, start:end] = 1
    # Global tokens attend everywhere
    for g in global_tokens:
        mask[g, :] = 1
        mask[:, g] = 1
    return mask

When to use: Very long sequences (documents >4K tokens, genomes, code files). Use cases: Document understanding, genomic sequences, long-form generation, legal document analysis. Hyperparameters:

  • window_size: Local attention window (256-512 typical)
  • global_tokens: Positions with full attention ([CLS], special tokens)
  • pattern: Strided, fixed, random, or learned Models using it: Longformer, BigBird, Sparse Transformer. Papers: "Generating Long Sequences with Sparse Transformers" (Child et al., 2019).

Linear Attention

What it does: Approximates attention in O(n) time using kernel tricks. Avoids computing full n×n attention matrix. The math:

Standard: Attention(Q,K,V) = softmax(QK^T)V  [O(n²)]
Linear: Attention(Q,K,V) = φ(Q)(φ(K)^T V)  [O(n)]
where φ is kernel function (e.g., elu+1)
python
def linear_attention(q, k, v, eps=1e-6):
    # Apply kernel function
    q = F.elu(q) + 1  # Ensure positive
    k = F.elu(k) + 1
    
    # Compute in linear order: K^T V first, then Q(K^T V)
    kv = torch.einsum('...nd,...ne->...de', k, v)  # [d, e]
    qkv = torch.einsum('...nd,...de->...ne', q, kv)  # [n, e]
    
    # Normalize
    k_sum = k.sum(dim=-2, keepdim=True)
    qk = torch.einsum('...nd,...nd->...n', q, k_sum)
    return qkv / (qk.unsqueeze(-1) + eps)

When to use: Very long sequences where O(n²) is prohibitive, streaming applications. Use cases: Processing millions of tokens, online/streaming inference, efficient transformers. Hyperparameters:

  • kernel_function: elu+1, relu, or learned kernel
  • epsilon: Numerical stability term (1e-6) Models using it: Performer, Linear Transformer, LUNA. Papers: "Rethinking Attention with Performers" (Choromanski et al., 2021).

Oscillating Attention (Bidirectional-Forward-Backward)

What it does: Alternates between bidirectional, forward (causal), and backward attention patterns. Provides varied context integration. The math:

Layer 1: Bidirectional (all-to-all)
Layer 2: Forward only (causal)
Layer 3: Backward only (reverse causal)
Pattern repeats or cycles
python
class OscillatingAttention(nn.Module):
    def __init__(self, num_layers=12):
        super().__init__()
        self.attention_patterns = []
        for i in range(num_layers):
            pattern = ['bidirectional', 'forward', 'backward'][i % 3]
            self.attention_patterns.append(pattern)
    
    def get_mask(self, seq_len, pattern):
        if pattern == 'bidirectional':
            return torch.zeros(seq_len, seq_len)  # No masking
        elif pattern == 'forward':
            return torch.triu(torch.ones(seq_len, seq_len) * float('-inf'), diagonal=1)
        else:  # backward
            return torch.tril(torch.ones(seq_len, seq_len) * float('-inf'), diagonal=-1)

When to use: Experimental architectures exploring diverse attention patterns, when different context types benefit different layers. Use cases: Research into attention pattern diversity, potential for improved bidirectional understanding. Hyperparameters:

  • pattern_cycle: Sequence of attention types per layer
  • cycle_length: How many layers before repeating pattern (typically 3) Models using it: Experimental research architectures, some recent language model variants. Note: Less common than other patterns, but useful for architecture search.

Positional Encodings

Sinusoidal (Transformer)

python
PE(pos, 2i) = sin(pos / 10000^(2i/d))

Rotary (RoPE)

python
# Rotates query/key by position-dependent angle

Papers: "RoFormer" (Su et al., 2021).

ALiBi (Attention with Linear Biases)

python
# Adds position-dependent bias to attention scores

Papers: "Train Short, Test Long" (Press et al., 2022).

Learnable Positional Embeddings

python
nn.Embedding(max_seq_len, embed_dim)

Sliding Window Attention

What it does: Each token attends only to a fixed-size window around it. The math: For token i, attention window is [i-w, i+w] where w is window size.

python
def sliding_window_attention(q, k, v, window_size=256):
    # Create attention mask for local window
    seq_len = q.size(1)
    mask = torch.zeros(seq_len, seq_len)
    for i in range(seq_len):
        start = max(0, i - window_size)
        end = min(seq_len, i + window_size + 1)
        mask[i, start:end] = 1
    
    attn = torch.softmax(q @ k.T / math.sqrt(d_k) + (1 - mask) * -1e9, dim=-1)
    return attn @ v

When to use: Long documents, reducing O(n²) cost while preserving local context. Models using it: Longformer, BigBird (combined with global attention). Papers: "Longformer: The Long-Document Transformer" (Beltagy et al., 2020).

Global + Local Attention (Hybrid)

What it does: Some tokens get global attention while most use local/sliding window.

python
# Longformer pattern:
# - Special tokens (CLS, SEP) attend globally
# - Content tokens use sliding window
global_tokens = [0]  # CLS token
for i in range(seq_len):
    if i in global_tokens:
        attend_to = list(range(seq_len))  # Global
    else:
        attend_to = sliding_window(i, window_size)  # Local

When to use: Document understanding, where some tokens need full context. Models using it: Longformer, LED (Longformer Encoder-Decoder).

Dilated Attention

What it does: Attention with gaps, attending to every k-th token.

python
# Attend to positions: i, i+k, i+2k, i+3k, ...
# Similar to dilated convolutions but for attention

When to use: Capturing patterns at multiple scales without full O(n²) cost. Models using it: Some BigBird variants.

Random Attention (BigBird)

What it does: Each token attends to random subset of tokens for global connectivity.

python
# Combine: local window + random tokens + global tokens
import random
def bigbird_attention_pattern(seq_len, window_size=3, num_random=3, num_global=2):
    patterns = []
    for i in range(seq_len):
        attend_to = set()
        # Local window
        attend_to.update(range(max(0, i-window_size), min(seq_len, i+window_size+1)))
        # Random tokens
        attend_to.update(random.sample(range(seq_len), num_random))
        # Global tokens
        attend_to.update(range(num_global))
        patterns.append(sorted(attend_to))
    return patterns

When to use: Very long sequences with sparse dependencies. Models using it: BigBird for genomics, long documents. Papers: "Big Bird: Transformers for Longer Sequences" (Zaheer et al., 2020).

Axial Attention

What it does: Factorizes 2D attention into row-wise and column-wise attention.

python
# For image patches arranged in H×W grid:
# 1. Attend within each row (W positions)
# 2. Attend within each column (H positions)
# Cost: O(H×W×(H+W)) instead of O((H×W)²)

When to use: High-resolution images, video, reducing attention cost for 2D inputs. Models using it: Axial-DeepLab for segmentation. Papers: "Axial Attention in Multidimensional Transformers" (Ho et al., 2019).

Perceiver / Cross-Attention to Latents

What it does: Large input attends to small set of learned latent vectors.

python
class PerceiverAttention(nn.Module):
    def __init__(self, num_latents=512, latent_dim=512, input_dim=1024):
        super().__init__()
        self.latents = nn.Parameter(torch.randn(num_latents, latent_dim))
        self.cross_attn = CrossAttention(latent_dim, input_dim)
    
    def forward(self, x):
        # x shape: [batch, seq_len, input_dim]
        # latents attend to input
        return self.cross_attn(self.latents, x)

When to use: Processing very large inputs (images, point clouds, multimodal) with fixed compute. Models using it: Perceiver, Perceiver IO for arbitrary inputs/outputs. Papers: "Perceiver: General Perception with Iterative Attention" (Jaegle et al., 2021).

Cached Attention (KV Cache)

What it does: Stores previous key/value pairs for autoregressive generation.

python
class CachedAttention:
    def __init__(self):
        self.cache_k = []
        self.cache_v = []
    
    def forward(self, q, k, v):
        # Append new k, v to cache
        self.cache_k.append(k)
        self.cache_v.append(v)
        k_full = torch.cat(self.cache_k, dim=1)
        v_full = torch.cat(self.cache_v, dim=1)
        return attention(q, k_full, v_full)

When to use: Text generation, reducing redundant computation during decoding. Models using it: All autoregressive Transformers (GPT, LLaMA) during inference.

Retentive Networks (RetNet)

What it does: Parallel training, recurrent inference via retention mechanism (alternative to attention). The math:

Retention: O_t = (Q_t K^T) ⊙ D_t V
where D_t is causal decay matrix
python
# Dual form: acts like attention (parallel) or RNN (sequential)

When to use: When you want Transformer quality with RNN efficiency at inference. Models using it: RetNet as GPT alternative. Papers: "Retentive Network: A Successor to Transformer" (Sun et al., 2023).

RWKV (Receptance Weighted Key Value)

What it does: RNN-like architecture with attention-like mixing, linear complexity.

python
# Time-mixing combines current and previous states
# Channel-mixing is position-wise FFN

When to use: Long context with constant memory, streaming inference. Models using it: RWKV language models (up to 14B parameters). Papers: RWKV architecture (BlinkDL, 2023).


Normalization Techniques

Batch Normalization (BatchNorm)

What it does: Normalizes activations across batch dimension. Reduces internal covariate shift. The math:

μ_B = (1/m)Σx_i  # Batch mean
σ²_B = (1/m)Σ(x_i - μ_B)²  # Batch variance
x̂ = (x - μ_B) / √(σ²_B + ε)  # Normalize
y = γx̂ + β  # Scale and shift (learnable)
python
nn.BatchNorm2d(num_features=64, eps=1e-5, momentum=0.1, affine=True)

When to use: CNNs with large batches (>16), most computer vision models. Use cases: ResNet, VGG, most CNNs, helps training stability and speed. Hyperparameters:

  • eps: Epsilon for numerical stability (1e-5 default, 1e-3 for FP16)
  • momentum: Running statistics momentum (0.1 default, higher = more stable)
  • affine: Learn γ and β (True default) Issues: Batch size dependent (fails with batch=1), different behavior train/test. Models using it: ResNet, DenseNet, VGG, Inception, most CNNs. Papers: "Batch Normalization: Accelerating Deep Network Training" (Ioffe & Szegedy, 2015).

Layer Normalization (LayerNorm)

What it does: Normalizes across feature dimension per sample. Batch-independent. The math:

μ = (1/d)Σx_i  # Mean across features (per sample)
σ² = (1/d)Σ(x_i - μ)²  # Variance across features
x̂ = (x - μ) / √(σ² + ε)
y = γx̂ + β  # Scale and shift
python
nn.LayerNorm(normalized_shape=512, eps=1e-5, elementwise_affine=True)

When to use: Transformers, RNNs, any model with small/variable batch sizes. Use cases: BERT, GPT, T5, all transformers, RNNs/LSTMs, online learning. Hyperparameters:

  • normalized_shape: Shape to normalize over (often [d_model])
  • eps: Numerical stability (1e-5 default, 1e-6 for transformers)
  • elementwise_affine: Learn γ and β (True default) Models using it: All Transformers (BERT, GPT, T5), LSTMs, GRUs. Papers: "Layer Normalization" (Ba et al., 2016).

Instance Normalization

What it does: Normalizes per sample, per channel. Each instance and channel normalized independently. The math: Same as BatchNorm but statistics computed per instance, per channel

python
nn.InstanceNorm2d(num_features=64, eps=1e-5, affine=False)

When to use: Style transfer, GANs, when instance appearance should be normalized independently. Use cases: Neural style transfer, CycleGAN, image-to-image translation. Hyperparameters:

  • eps: Numerical stability (1e-5 default)
  • affine: Usually False for style transfer Models using it: Style transfer networks, CycleGAN, Pix2Pix, artistic style models. Papers: "Instance Normalization" (Ulyanov et al., 2016).

Group Normalization

What it does: Divides channels into groups, normalizes within each group. Batch-independent like LayerNorm. The math:

Split C channels into G groups
Normalize within each group (like LayerNorm per group)
python
nn.GroupNorm(num_groups=8, num_channels=64, eps=1e-5, affine=True)

When to use: Small batch sizes, video models, when BatchNorm fails. Use cases: Object detection (Mask R-CNN), video understanding, batch size 1-2. Hyperparameters:

  • num_groups: Number of groups (8 or 32 typical, must divide num_channels)
  • eps: Numerical stability (1e-5 default) Special cases: num_groups=1 → LayerNorm, num_groups=num_channels → InstanceNorm Models using it: Mask R-CNN, video models, models with small batches. Papers: "Group Normalization" (Wu & He, 2018).

RMS Normalization (RMSNorm)

What it does: Simpler than LayerNorm, no mean centering. Just rescales by RMS. The math:

RMS(x) = √((1/d)Σx²)
y = x / (RMS(x) + ε) · γ  # No mean subtraction, no β
python
class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))
    
    def forward(self, x):
        rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
        return x / rms * self.weight

When to use: Large language models, when you want faster than LayerNorm. Use cases: LLaMA, modern LLMs prioritizing speed, efficient transformers. Hyperparameters:

  • eps: Numerical stability (1e-6 typical for LLMs) Advantage: ~10-20% faster than LayerNorm, fewer operations. Models using it: LLaMA, Gopher, Chinchilla, modern efficient LLMs. Papers: "Root Mean Square Layer Normalization" (Zhang & Sennrich, 2019).

Weight Normalization

What it does: Reparameterizes weight matrix as direction × magnitude. Normalizes weight vectors. The math:

w = g · v / ||v||  # g is magnitude (scalar), v is direction (vector)
python
layer = nn.Linear(512, 256)
layer = nn.utils.weight_norm(layer, name='weight', dim=0)

When to use: GANs, reinforcement learning, alternative to BatchNorm. Use cases: GAN generators, RL policy networks, when batch normalization isn't suitable. Models using it: Some GAN variants, RL algorithms. Papers: "Weight Normalization" (Salimans & Kingma, 2016).

Spectral Normalization

What it does: Constrains spectral norm (largest singular value) of weight matrices. Lipschitz constraint. The math:

W_SN = W / σ(W)  # σ(W) is largest singular value
python
layer = nn.Linear(512, 256)
layer = nn.utils.spectral_norm(layer, n_power_iterations=1)

When to use: GAN discriminators for training stabilization. Use cases: SNGAN, BigGAN, improving GAN training stability. Hyperparameters:

  • n_power_iterations: Iterations to compute spectral norm (1 default, more = accurate) Models using it: Spectral Normalization GAN (SNGAN), BigGAN, stable GAN training. Papers: "Spectral Normalization for GANs" (Miyato et al., 2018).

Activation Functions

ReLU (Rectified Linear Unit)

python
nn.ReLU()  # max(0, x)

When to use: Default for most networks.

Leaky ReLU

python
nn.LeakyReLU(negative_slope=0.01)

When to use: Prevent dead neurons.

PReLU (Parametric ReLU)

What it does: ReLU with learnable negative slope per channel. The math: f(x) = max(0, x) + a·min(0, x) where a is learned

python
nn.PReLU(num_parameters=1, init=0.25)  # num_parameters: 1 (shared) or num_channels

When to use: When you want to learn the negative slope instead of fixing it. Use cases: Deep CNNs, when Leaky ReLU helps but optimal slope is unknown. Hyperparameters:

  • num_parameters: 1 (single α for all channels) or C (per-channel α)
  • init: Initial value for α (0.25 default) Models using it: Some ResNet variants, image classification models.

ELU (Exponential Linear Unit)

What it does: Smooth activation with negative values, pushes mean activations closer to zero. The math: f(x) = x if x > 0 else α(e^x - 1)

python
nn.ELU(alpha=1.0, inplace=False)

When to use: Want smoother gradients than ReLU, especially for deep networks. Use cases: Deep CNNs, networks with many layers, when ReLU causes dead neurons. Hyperparameters:

  • alpha: Scaling for negative values (1.0 default, controls negative saturation) Models using it: Some deep CNNs, experimental architectures. Papers: "Fast and Accurate Deep Network Learning by ELU" (Clevert et al., 2016).

SELU (Scaled ELU)

What it does: Self-normalizing activation that maintains mean≈0 and variance≈1 through layers. Special case of ELU. The math: f(x) = λx if x > 0 else λα(e^x - 1) where λ≈1.0507, α≈1.6733

python
nn.SELU(inplace=False)  # Fixed λ and α for self-normalization

When to use: Self-normalizing networks (SNNs), when you want to avoid Batch Normalization. Use cases: Deep MLPs, FNNs where batch norm is problematic (small batches, online learning). Requirements: Must use with lecun_normal weight initialization and Alpha Dropout. Models using it: Self-normalizing neural networks for tabular data. Papers: "Self-Normalizing Neural Networks" (Klambauer et al., 2017).

GELU (Gaussian Error Linear Unit)

What it does: Smooth, non-monotonic activation combining properties of ReLU and dropout. Weighs input by its percentile. The math: f(x) = x·Φ(x) where Φ(x) is cumulative distribution function of standard normal

python
nn.GELU(approximate='none')  # 'none' for exact, 'tanh' for fast approximation

Approximation: f(x) ≈ 0.5x(1 + tanh[√(2/π)(x + 0.044715x³)]) When to use: Transformers, modern NLP models, vision transformers. Use cases: BERT, GPT, ViT, any transformer-based model. Hyperparameters:

  • approximate: 'none' (exact but slower) or 'tanh' (fast approximation) Models using it: BERT, GPT, T5, ViT, most transformers. Papers: "Gaussian Error Linear Units" (Hendrycks & Gimpel, 2016).

SiLU / Swish

What it does: Smooth, non-monotonic activation. Self-gated version of identity (x·σ(x)). The math: f(x) = x·sigmoid(βx) where β=1 for SiLU/Swish-1

python
nn.SiLU(inplace=False)  # Same as Swish with β=1
# Or manually: x * torch.sigmoid(x)

When to use: Modern CNNs, MobileNets, EfficientNet, better than ReLU in many vision tasks. Use cases: EfficientNet, MobileNetV3, modern CNN architectures, mobile/edge models. Hyperparameters:

  • β: Sigmoid scaling (1.0 for SiLU, learnable in original Swish) Models using it: EfficientNet, MobileNetV3, YOLOv5, many modern CNNs. Papers: "Searching for Activation Functions" (Ramachandran et al., 2017).

Mish

What it does: Smooth, non-monotonic activation similar to Swish but slightly smoother. The math: f(x) = x·tanh(softplus(x)) = x·tanh(ln(1 + e^x))

python
class Mish(nn.Module):
    def forward(self, x):
        return x * torch.tanh(F.softplus(x))

When to use: When you want smoother than ReLU/Swish, experimental improvements in CNNs. Use cases: Object detection, image classification, experimental architectures. Models using it: YOLOv4, some experimental CNN variants. Papers: "Mish: A Self Regularized Non-Monotonic Activation Function" (Misra, 2019).

GLU (Gated Linear Unit)

What it does: Gating mechanism that allows network to control information flow. Splits input and gates one half with sigmoid of other. The math: GLU(x) = (xW + b) ⊗ σ(xV + c)

python
class GLU(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.linear = nn.Linear(dim, dim)
        self.gate = nn.Linear(dim, dim)
    
    def forward(self, x):
        return self.linear(x) * torch.sigmoid(self.gate(x))

When to use: Language models, Transformers FFN, when gating helps control information flow. Use cases: GPT-2 (in FFN with variants), language modeling, gated networks. Variants:

  • ReGLU: Gates with ReLU (linear(x) * relu(gate(x)))
  • SwiGLU: Gates with SiLU (used in LLaMA, PaLM)
  • GeGLU: Gates with GELU Models using it: GPT-2/3 variants, LLaMA (SwiGLU), PaLM (SwiGLU). Papers: "Language Modeling with Gated Convolutional Networks" (Dauphin et al., 2017).*

Softmax

What it does: Converts logits to probability distribution. Exponentiates and normalizes. The math: softmax(x_i) = e^(x_i) / Σ_j e^(x_j)

python
nn.Softmax(dim=-1)  # Normalize over specified dimension
F.softmax(logits, dim=-1)  # Functional version

When to use: Multi-class classification output layer, attention weights, any probability distribution over classes. Use cases: Classification head, attention mechanisms, language model token probabilities. Hyperparameters:

  • dim: Dimension to normalize over (-1 for last dimension) Temperature scaling: softmax(x/T) where T controls sharpness (T<1: sharper, T>1: smoother) Models using it: Every classification model, all transformers (in attention). Numerical stability: Use log-softmax when computing cross-entropy to avoid overflow.

Sigmoid

python
nn.Sigmoid()

When to use: Binary classification, gates.

Tanh

python
nn.Tanh()

When to use: RNNs, range [-1, 1].

Hardswish / Hardsigmoid

python
nn.Hardswish()

When to use: Mobile models (MobileNetV3).


Pooling & Downsampling

Max Pooling

python
nn.MaxPool2d(kernel_size=2, stride=2)

When to use: Spatial downsampling, CNNs.

Average Pooling

python
nn.AvgPool2d(kernel_size=2)

When to use: Smoother downsampling.

Global Average Pooling (GAP)

python
nn.AdaptiveAvgPool2d(output_size=1)

When to use: Replace FC layers (ResNet, NiN).

Global Max Pooling

python
nn.AdaptiveMaxPool2d(output_size=1)

Adaptive Pooling

python
nn.AdaptiveAvgPool2d(output_size=(7, 7))

When to use: Variable input sizes.

Stochastic Pooling

What it does: Randomly samples from pooling region. When to use: Regularization during training.

Fractional Max Pooling

What it does: Non-integer downsampling ratios.

python
nn.FractionalMaxPool2d(kernel_size=2, output_ratio=0.5)

Optimization Algorithms

SGD (Stochastic Gradient Descent)

What it does: Updates parameters in direction of negative gradient. With momentum, accumulates velocity. The math:

Standard SGD: θ_t = θ_{t-1} - α∇L(θ_{t-1})
With Momentum: v_t = μv_{t-1} + ∇L(θ_{t-1})
               θ_t = θ_{t-1} - αv_t
Nesterov: v_t = μv_{t-1} + ∇L(θ_{t-1} - αμv_{t-1})
python
optim.SGD(params, lr=0.01, momentum=0.9, nesterov=True, weight_decay=1e-4)

When to use: Simple baseline, large-batch training, CNNs with batch normalization. Use cases: Training ResNets on ImageNet, simple baselines, when Adam over-fits. Hyperparameters:

  • lr: Learning rate (0.001-0.1, often 0.01 for vision)
  • momentum: Momentum coefficient (0.9 typical, 0.99 for large batches)
  • nesterov: Use Nesterov momentum (True recommended)
  • weight_decay: L2 regularization (1e-4 to 1e-5) Models using it: ResNet (with momentum 0.9), VGG, AlexNet.

Adam (Adaptive Moment Estimation)

What it does: Adaptive learning rates per parameter using first and second moment estimates. Combines RMSprop and momentum. The math:

m_t = β₁m_{t-1} + (1-β₁)g_t          # First moment (momentum)
v_t = β₂v_{t-1} + (1-β₂)g_t²         # Second moment (variance)
m̂_t = m_t/(1-β₁ᵗ)                    # Bias correction
v̂_t = v_t/(1-β₂ᵗ)                    # Bias correction
θ_t = θ_{t-1} - α·m̂_t/(√v̂_t + ε)    # Update
python
optim.Adam(params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0)

When to use: Default choice for most tasks, especially when SGD is unstable. Use cases: Training transformers, GANs, vision models, multi-task learning. Hyperparameters:

  • lr: Learning rate (1e-4 to 1e-3, often 1e-3)
  • betas: (β₁, β₂) momentum coefficients ((0.9, 0.999) default, (0.9, 0.98) for transformers)
  • eps: Epsilon for numerical stability (1e-8 default, 1e-6 for FP16)
  • weight_decay: Weight decay (0 for Adam, use AdamW instead) Models using it: Early GPT, BERT (before AdamW), VAE, many research models. Papers: "Adam: A Method for Stochastic Optimization" (Kingma & Ba, 2015).

AdamW (Adam with Decoupled Weight Decay)

What it does: Fixes Adam by decoupling weight decay from gradient-based update. Proper L2 regularization. The math:

m_t = β₁m_{t-1} + (1-β₁)g_t
v_t = β₂v_{t-1} + (1-β₂)g_t²
m̂_t = m_t/(1-β₁ᵗ)
v̂_t = v_t/(1-β₂ᵗ)
θ_t = θ_{t-1} - α·m̂_t/(√v̂_t + ε) - α·λ·θ_{t-1}  # Decoupled weight decay
python
optim.AdamW(params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01)

When to use: Transformers, language models, any model where proper regularization matters. Use cases: BERT, GPT training, fine-tuning large models, modern deep learning default. Hyperparameters:

  • lr: Learning rate (1e-4 to 5e-4 for LLMs, 1e-3 for smaller models)
  • betas: (0.9, 0.999) default, (0.9, 0.98) for transformers
  • eps: 1e-8 default, 1e-6 or 1e-7 for mixed precision (FP16)
  • weight_decay: 0.01 to 0.1 (0.01 typical, higher for large models) Models using it: BERT, GPT-2/3, LLaMA, all modern transformers, ViT. Papers: "Decoupled Weight Decay Regularization" (Loshchilov & Hutter, 2019).

Adagrad (Adaptive Gradient)

What it does: Adapts learning rate per parameter based on historical gradient magnitudes. Good for sparse data. The math:

G_t = G_{t-1} + g_t²  # Accumulate squared gradients
θ_t = θ_{t-1} - α·g_t/(√G_t + ε)
python
optim.Adagrad(params, lr=0.01, eps=1e-10, weight_decay=0)

When to use: Sparse gradients (NLP with rare words, recommender systems). Use cases: Word embeddings with rare tokens, sparse feature learning, online learning. Hyperparameters:

  • lr: Learning rate (0.01 typical, higher than Adam)
  • eps: Epsilon for stability (1e-10 default)
  • weight_decay: L2 penalty (typically 0) Issue: Learning rate decays too aggressively, rarely used now. Replaced by: Adam, AdamW for most tasks.

Adadelta

What it does: Extension of Adagrad that uses moving window of gradients instead of accumulating all. No manual learning rate. The math:

E[g²]_t = ρE[g²]_{t-1} + (1-ρ)g_t²
RMS[g]_t = √(E[g²]_t + ε)
Δθ_t = -(RMS[Δθ]_{t-1}/RMS[g]_t)·g_t
θ_t = θ_{t-1} + Δθ_t
python
optim.Adadelta(params, lr=1.0, rho=0.9, eps=1e-6, weight_decay=0)

When to use: When you want adaptive learning rate without tuning lr hyperparameter. Use cases: Experimentation when you don't want to tune learning rate. Hyperparameters:

  • lr: Learning rate scaling (1.0 default, usually keep as is)
  • rho: Decay rate for moving average (0.9 typical)
  • eps: Numerical stability (1e-6) Rarely used: Adam/AdamW generally perform better.

RMSprop (Root Mean Square Propagation)

What it does: Adapts learning rate using moving average of squared gradients. Like Adagrad but with decay. The math:

v_t = βv_{t-1} + (1-β)g_t²
θ_t = θ_{t-1} - α·g_t/(√v_t + ε)
python
optim.RMSprop(params, lr=1e-3, alpha=0.99, eps=1e-8, weight_decay=0, momentum=0)

When to use: RNNs, non-stationary objectives, reinforcement learning. Use cases: Training RNNs/LSTMs, online learning, RL (DQN, A3C). Hyperparameters:

  • lr: Learning rate (1e-3 to 1e-2, often 1e-3)
  • alpha: Decay rate for squared gradient moving average (0.99 default)
  • eps: Numerical stability (1e-8, use 1e-6 for FP16)
  • momentum: Optional momentum term (0 default, 0.9 sometimes used) Models using it: Early RNN training, AlphaGo, DQN. Note: Adam has largely replaced it.

Lion

python
# from lion_pytorch import Lion
Lion(params, lr=1e-4)

When to use: More memory-efficient than Adam. Papers: "Symbolic Discovery of Optimization Algorithms" (Chen et al., 2023).

Sophia (Second-order)

python
# Approximates Hessian diagonal

When to use: LLM pre-training (2× faster than Adam). Papers: "Sophia" (Liu et al., 2023).

LAMB (Layer-wise Adaptive Moments)

What it does: Enables large-batch training (up to 64K batch size) without loss of accuracy by adapting learning rates per layer based on weight/gradient norms.

The Large-Batch Problem: Traditional optimizers (SGD, Adam) suffer from the generalization gap when using large batches:

  • Small batch (256): Good generalization, slow training
  • Large batch (32K+): Fast training, poor generalization
  • LAMB solves this by maintaining generalization while scaling to massive batches

Core Innovation: LAMB combines:

  1. Layer-wise adaptation (from LARS) - scales LR per layer
  2. Adaptive moments (from Adam) - per-parameter adaptive learning rates
  3. Trust ratio - prevents gradient explosion in large-batch settings

Mathematical Foundation:

Standard Adam Update:

python
m_t = β₁ * m_{t-1} + (1 - β₁) * g_t        # First moment (momentum)
v_t = β₂ * v_{t-1} + (1 - β₂) * g_t²       # Second moment (variance)
m̂_t = m_t / (1 - β₁^t)                     # Bias correction
v̂_t = v_t / (1 - β₂^t)
θ_t = θ_{t-1} - α * m̂_t / (√v̂_t + ε)

LAMB Update (Layer-wise):

python
# For each layer l:
m_t^l = β₁ * m_{t-1}^l + (1 - β₁) * g_t^l
v_t^l = β₂ * v_{t-1}^l + (1 - β₂) * (g_t^l)²
m̂_t^l = m_t^l / (1 - β₁^t)
v̂_t^l = v_t^l / (1 - β₂^t)

# Compute adaptive update
r_t^l = m̂_t^l / (√v̂_t^l + ε) + λ * θ_{t-1}^l  # Add weight decay

# Layer-wise trust ratio (key innovation)
φ(||θ_{t-1}^l||) = min(max(||θ_{t-1}^l||, γ_l), γ_u)  # Clipping
ψ(||r_t^l||) = min(max(||r_t^l||, γ_l), γ_u)

trust_ratio = φ(||θ_{t-1}^l||) / ψ(||r_t^l||)

# Final update with trust ratio
θ_t^l = θ_{t-1}^l - α * trust_ratio * r_t^l

Why Trust Ratio Works:

  • Prevents large updates when gradients are noisy (large batches)
  • Scales learning rate based on weight norm / update norm
  • Ensures stable training even with batch sizes of 32K-64K

Implementation:

PyTorch (pytorch-lamb):

python
from pytorch_lamb import Lamb

model = YourModel()
optimizer = Lamb(
    model.parameters(),
    lr=0.002,              # Base learning rate
    betas=(0.9, 0.999),    # Adam-style moments
    eps=1e-6,
    weight_decay=0.01,     # Decoupled weight decay (like AdamW)
    adam=False,            # Use LAMB (not Adam)
    clamp_value=10.0       # Gradient clipping
)

# Training loop
for batch in dataloader:
    optimizer.zero_grad()
    loss = model(batch)
    loss.backward()
    optimizer.step()

Manual Implementation (Conceptual):

python
import torch

class LAMB(torch.optim.Optimizer):
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, 
                 weight_decay=0.01, adam=False):
        defaults = dict(lr=lr, betas=betas, eps=eps, 
                       weight_decay=weight_decay, adam=adam)
        super(LAMB, self).__init__(params, defaults)
    
    def step(self, closure=None):
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                
                grad = p.grad.data
                state = self.state[p]
                
                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p.data)
                    state['exp_avg_sq'] = torch.zeros_like(p.data)
                
                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']
                state['step'] += 1
                
                # Exponential moving averages
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                
                # Bias correction
                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']
                
                # Adaptive update
                update = exp_avg / bias_correction1
                update = update / (exp_avg_sq / bias_correction2).sqrt().add_(group['eps'])
                
                # Add weight decay
                if group['weight_decay'] != 0:
                    update.add_(p.data, alpha=group['weight_decay'])
                
                # Compute trust ratio (LAMB-specific)
                weight_norm = p.data.norm()
                update_norm = update.norm()
                
                if weight_norm > 0 and update_norm > 0:
                    trust_ratio = weight_norm / update_norm
                else:
                    trust_ratio = 1.0
                
                # Apply update with trust ratio
                p.data.add_(update, alpha=-group['lr'] * trust_ratio)

When to Use LAMB:

Perfect for:

  • Large-batch training (batch size > 8K)
  • BERT/Transformer pre-training (original use case)
  • Distributed training across many GPUs/TPUs
  • Time-constrained training (need to finish quickly)
  • Cloud training (minimize cost via large batches)

Not ideal for:

  • Small batch sizes (<512) - use Adam/AdamW instead
  • Small models - overhead not worth it
  • Fine-tuning - Adam/AdamW work better
  • Computer vision (ResNet) - use LARS instead

Hyperparameter Recommendations:

BERT Pre-training (Original Paper):

python
batch_size = 65536          # 64K batch size!
lr = 0.00176                # Scaled from base 0.00176
warmup_steps = 10000        # Linear warmup
betas = (0.9, 0.999)
eps = 1e-6
weight_decay = 0.01
max_steps = 1000000

General Transformer Training:

python
batch_size = 8192-32768     # Scale as large as memory allows
lr = 0.002                  # Base LR (scale with batch size)
warmup_proportion = 0.1     # 10% of training for warmup
betas = (0.9, 0.999)
weight_decay = 0.01
gradient_clip = 1.0         # Clip gradients

Learning Rate Scaling Rule:

python
# Linear scaling with batch size
base_lr = 0.002
base_batch = 256
actual_batch = 32768

scaled_lr = base_lr * (actual_batch / base_batch)
# Example: 0.002 * (32768 / 256) = 0.256

Advantages:

Massive batch sizes (32K-64K) without accuracy loss
Faster training (fewer iterations, more parallelism)
Better scaling across distributed systems
Stable convergence even with noisy gradients
Works across architectures (BERT, RoBERTa, ELECTRA)

Limitations:

Complexity - more hyperparameters than Adam
Memory overhead - stores layer-wise statistics
Not universal - CV models prefer LARS
Warmup required - sensitive to initial learning rate
Overkill for small batches - Adam is simpler

LAMB vs. Other Optimizers:

OptimizerBest Batch SizeUse CaseComplexity
SGD32-256Baseline, simple tasksLow
Adam32-512General deep learningLow
AdamW32-1024Transformers (small batch)Low
LARS8K-32KResNet, computer visionMedium
LAMB16K-64KBERT, NLP transformersMedium

Comparison to LARS:

  • LARS: Layer-wise adaptation only (no adaptive moments)
    • Better for CNNs (ResNet, EfficientNet)
    • Simpler, less memory
  • LAMB: Layer-wise + adaptive moments
    • Better for Transformers (BERT, GPT)
    • More sophisticated, higher memory

Real-World Results (from Paper):

BERT-Large Pre-training:

  • Baseline (Adam, batch=256): 3 days on 64 TPUs
  • LAMB (batch=32K): 76 minutes on 1024 TPUs
  • Accuracy: Identical (no generalization gap!)

Key Insight: LAMB enables 50× speedup via massive batches without sacrificing quality.

Practical Tips:

  1. Start with warmup: Use 5-10% of training steps for linear LR warmup
  2. Scale LR linearly: lr = base_lr * (batch_size / base_batch)
  3. Use gradient clipping: Clip at 1.0 to prevent explosions
  4. Monitor layer norms: Check weight/gradient norms per layer
  5. Tune weight decay: 0.01 is standard, but try 0.001-0.1
  6. Batch size sweet spot: 8K-32K (beyond 64K shows diminishing returns)*

Advanced Techniques:

Mixed Precision + LAMB:

python
from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()
optimizer = Lamb(model.parameters(), lr=0.002)

for batch in dataloader:
    optimizer.zero_grad()
    
    with autocast():  # FP16 forward pass
        loss = model(batch)
    
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    scaler.step(optimizer)
    scaler.update()

Distributed Training (PyTorch DDP):

python
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# Initialize process group
dist.init_process_group(backend='nccl')

model = DDP(model, device_ids=[local_rank])
optimizer = Lamb(model.parameters(), lr=0.002)

# Gradient accumulation for effective large batch
accumulation_steps = 4
for i, batch in enumerate(dataloader):
    loss = model(batch) / accumulation_steps
    loss.backward()
    
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()

Papers & References:

Primary Paper:

  • "Large Batch Optimization for Deep Learning: Training BERT in 76 minutes"
    You, Yang, et al. (ICLR 2020)
    • Introduces LAMB optimizer
    • Proves convergence guarantees
    • Demonstrates 64K batch BERT training

Related Work:

  • "Large Batch Training of Convolutional Networks" (LARS paper, Ginsburg et al., 2018)
  • "Accurate, Large Minibatch SGD" (Goyal et al., 2017) - Linear scaling rule
  • "Don't Decay the Learning Rate, Increase the Batch Size" (Smith et al., 2018)

Applications:

  • NLP Pre-training: BERT, RoBERTa, ELECTRA, T5
  • Large-scale transformers: GPT-style models (when batch size > 8K)
  • Multimodal models: CLIP, DALL-E (vision + language)
  • Speech models: Wav2Vec 2.0, HuBERT
  • Cloud/distributed training: Minimize wall-clock time

Debugging LAMB:

Common Issues:

  1. Loss diverges early:

    • Reduce LR (try 0.001 instead of 0.002)
    • Increase warmup steps
    • Check gradient clipping is enabled
  2. Slower than expected:

    • Ensure batch size is actually large (>8K)
    • Check GPU utilization (should be >90%)
    • Verify data loading isn't bottleneck
  3. Worse accuracy than Adam:

    • Increase training steps (large batches need more steps)
    • Tune weight decay (try 0.001, 0.01, 0.1)
    • Verify LR scaling is correct

Example: BERT Pre-training with LAMB

python
from transformers import BertForMaskedLM, BertConfig
from pytorch_lamb import Lamb

# Model
config = BertConfig(vocab_size=30522, hidden_size=768, num_hidden_layers=12)
model = BertForMaskedLM(config).cuda()

# LAMB optimizer
optimizer = Lamb(
    model.parameters(),
    lr=0.00176,
    betas=(0.9, 0.999),
    eps=1e-6,
    weight_decay=0.01
)

# Linear warmup scheduler
from transformers import get_linear_schedule_with_warmup
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=10000,
    num_training_steps=1000000
)

# Training loop
for step, batch in enumerate(train_dataloader):
    optimizer.zero_grad()
    
    outputs = model(
        input_ids=batch['input_ids'],
        attention_mask=batch['attention_mask'],
        labels=batch['labels']
    )
    
    loss = outputs.loss
    loss.backward()
    
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
    scheduler.step()
    
    if step % 1000 == 0:
        print(f"Step {step}, Loss: {loss.item():.4f}")

Summary: LAMB is the go-to optimizer for large-batch transformer training. If you're pre-training BERT/GPT-style models on multiple GPUs/TPUs and want to minimize training time, LAMB enables batch sizes of 32K-64K without sacrificing accuracy. For smaller batches (<8K) or fine-tuning, stick with AdamW.

LARS (Layer-wise Adaptive Rate Scaling)

What it does: Scales learning rate per layer based on weight and gradient norms, enabling large-batch training for vision models.

The math:

For each layer l:
λ^l = η * ||w^l|| / (||∇L(w^l)|| + β||w^l||)
w^l_new = w^l - λ^l * ∇L(w^l)

where:
- η is global learning rate
- ||w^l|| is L2 norm of layer weights
- ||∇L(w^l)|| is L2 norm of gradients
- β is weight decay coefficient
python
# PyTorch LARS
from torchlars import LARS

base_optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
optimizer = LARS(optimizer=base_optimizer, eps=1e-8, trust_coef=0.001)

When to use:

  • Large-batch ResNet/EfficientNet training (batch size >8K)
  • Computer vision models at scale
  • Distributed training of CNNs

Models using it: ResNet-50 trained with batch size 32K, ImageNet in 1 hour training.

LARS vs LAMB:

  • LARS: Better for CNNs (ResNet, EfficientNet)
  • LAMB: Better for Transformers (BERT, GPT)

Papers: "Large Batch Training of Convolutional Networks" (Ginsburg et al., 2018).

Adafactor

What it does: Memory-efficient adaptive optimizer that doesn't store full second moments.

The innovation: Factorizes the second moment matrix to reduce memory from O(n²) to O(n).

The math: Instead of storing full v_t for each parameter, factorize:

For matrix parameters (d×k):
- Store row factors: R_t ∈ ℝ^d
- Store column factors: C_t ∈ ℝ^k  
- Reconstruct: v_t ≈ R_t ⊗ C_t

Memory: O(d×k) → O(d+k)
python
from transformers import Adafactor

optimizer = Adafactor(
    model.parameters(),
    scale_parameter=True,      # Scale by RMS of weights
    relative_step=True,        # Use dynamic learning rate
    warmup_init=True,          # Start with tiny LR
    lr=None                    # Let Adafactor decide
)

When to use:

  • Training huge models where memory is tight (T5, mT5)
  • When you can't fit Adam's optimizer states
  • Long training runs where saving memory matters

Models using it: T5, mT5 (Google's multilingual T5).

Memory savings:

  • Adam/AdamW: 2× model params (m_t + v_t)
  • Adafactor: ~1.3× model params (factorized storage)

Trade-offs:

  • ✅ Much lower memory usage
  • ✅ Still adaptive like Adam
  • ❌ Slightly less stable than Adam
  • ❌ May need more tuning

Papers: "Adafactor: Adaptive Learning Rates with Sublinear Memory Cost" (Shazeer & Stern, 2018).

NovoGrad

What it does: Variant of Adam that normalizes gradients by layer-wise stochastic norm.

The math:

g_t^l = ∇L_t(w^l)                                    # Layer gradients
v_t^l = β₂ * v_{t-1}^l + (1 - β₂) * ||g_t^l||²      # Second moment of gradient norm
m_t^l = β₁ * m_{t-1}^l + g_t^l / (√v_t^l + ε)      # Normalized gradient with momentum
w_t^l = w_{t-1}^l - α_t * (m_t^l + λ * w_{t-1}^l)  # Update with weight decay
python
# From NVIDIA's implementation
from apex.optimizers import NovoGrad

optimizer = NovoGrad(
    model.parameters(),
    lr=1e-3,
    betas=(0.95, 0.98),
    weight_decay=0.001
)

When to use:

  • Speech recognition models (Jasper, QuartzNet)
  • Large batch training for audio/NLP
  • When Adam plateaus or converges slowly

Models using it: Jasper (NVIDIA's speech model), QuartzNet.

Benefits:

  • More stable than Adam for large models
  • Works well with large learning rates
  • Less sensitive to hyperparameters

Papers: "Stochastic Gradient Methods with Layer-wise Adaptive Moments" (Ginsburg et al., 2019).

Lookahead

What it does: Meta-optimizer that wraps any base optimizer, maintains slow-moving weights for stability.

The algorithm:

1. Inner optimizer takes k fast steps
2. Lookahead updates slow weights toward fast weights
3. Reset fast weights to slow weights

Mathematically:
θ_fast updates for k steps with base optimizer
θ_slow = θ_slow + α * (θ_fast - θ_slow)
θ_fast = θ_slow
python
from lookahead import Lookahead

# Wrap any optimizer
base_optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
optimizer = Lookahead(base_optimizer, k=5, alpha=0.5)

# Or with SGD
base_optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
optimizer = Lookahead(base_optimizer, k=5, alpha=0.5)

When to use:

  • Stabilizing training that's noisy or oscillating
  • Combining with aggressive optimizers (high LR SGD)
  • Improving generalization

Benefits:

  • ✅ Works with any base optimizer
  • ✅ Reduces variance in training
  • ✅ Often improves generalization
  • ✅ Minimal hyperparameter tuning (k=5, α=0.5 work well)

Papers: "Lookahead Optimizer: k steps forward, 1 step back" (Zhang et al., 2019).

SAM (Sharpness-Aware Minimization)

What it does: Seeks flat minima instead of sharp minima, improving generalization by minimizing both loss and loss sharpness.

The key insight: Flat minima generalize better than sharp minima. SAM finds parameters where the loss is low AND the gradient of loss w.r.t. perturbations is low.

The math:

Minimize: max_{||ε||≤ρ} L(w + ε)

Two-step process:
1. Compute adversarial perturbation: ε = ρ * ∇L(w) / ||∇L(w)||
2. Update weights using gradient at perturbed point: w ← w - α * ∇L(w + ε)
python
# Using SAM from torch-optimizer
from torch_optimizer import SAM

base_optimizer = torch.optim.SGD
optimizer = SAM(model.parameters(), base_optimizer, lr=0.1, momentum=0.9)

# Training loop (requires two forward passes)
for batch in dataloader:
    def closure():
        optimizer.zero_grad()
        loss = model(batch)
        loss.backward()
        return loss
    
    loss = optimizer.step(closure)

When to use:

  • When you want better generalization over training accuracy
  • Computer vision (achieves SOTA on ImageNet)
  • Small datasets where overfitting is a concern
  • Can be combined with any base optimizer (SGD, Adam)

Models using it: Vision models that achieve SOTA generalization (ResNets, ViTs).

Benefits:

  • ✅ Significantly better generalization
  • ✅ Works across architectures
  • ✅ Proven theoretical guarantees
  • ❌ 2× training time (two forward passes per step)

Papers: "Sharpness-Aware Minimization for Efficiently Improving Generalization" (Foret et al., 2021).

Adan (Adaptive Nesterov Momentum)

What it does: Combines Nesterov momentum with adaptive learning rates, claims to be better than Adam.

python
# From official implementation
from adan import Adan

optimizer = Adan(
    model.parameters(),
    lr=1e-3,
    betas=(0.98, 0.92, 0.99),    # Three momentum terms
    weight_decay=0.02
)

When to use: Vision and NLP models, claims to train faster than Adam/AdamW.

Models using it: Experimental use in ViT, ResNets.

Papers: "Adan: Adaptive Nesterov Momentum Algorithm" (Xie et al., 2022).

Learning Rate Schedulers

StepLR - Decay every N epochs

python
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
# Multiplies LR by 0.1 every 30 epochs

CosineAnnealingLR - Cosine decay

python
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)
# Smooth cosine decrease from initial LR to 0 over T_max epochs

OneCycleLR - Super-convergence

python
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=0.1, total_steps=10000, pct_start=0.3
)
# Increases LR for 30% of training, then decreases

When to use: Fast training with large learning rates (Leslie Smith's super-convergence).

ReduceLROnPlateau - Adaptive based on metric

python
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.1, patience=10
)
# Reduces LR when metric stops improving

Warmup - Linear warmup then decay

python
from transformers import get_linear_schedule_with_warmup

scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=10000,      # Linearly increase LR for first 10K steps
    num_training_steps=100000    # Then linearly decay to 0
)

When to use: Transformers, LLMs (BERT, GPT). Prevents early instability with large LR.


Loss Functions

Classification

Cross-Entropy Loss

python
nn.CrossEntropyLoss()

When to use: Multi-class classification.

Binary Cross-Entropy

python
nn.BCEWithLogitsLoss()

When to use: Binary classification, multi-label.

Focal Loss

python
# Addresses class imbalance
# FL(p) = -(1-p)^γ * log(p)

When to use: Object detection (RetinaNet). Papers: "Focal Loss for Dense Object Detection" (Lin et al., 2017).

Label Smoothing Cross-Entropy

python
# Softens one-hot targets

When to use: Regularization, calibration.

Regression

MSE (Mean Squared Error)

python
nn.MSELoss()

When to use: Regression tasks.

MAE (Mean Absolute Error)

python
nn.L1Loss()

When to use: Robust to outliers.

Huber Loss

python
nn.SmoothL1Loss()

When to use: Combines MSE + MAE benefits.

Metric Learning

Triplet Loss

python
nn.TripletMarginLoss()

When to use: Face recognition, embedding learning.

Contrastive Loss

python
# Pulls similar pairs together, pushes dissimilar apart

When to use: Siamese networks.

ArcFace / CosFace

python
# Angular margin losses

When to use: Face recognition SOTA. Papers: "ArcFace" (Deng et al., 2019).

NT-Xent (Normalized Temperature-scaled Cross-Entropy)

python
# SimCLR contrastive loss

When to use: Self-supervised learning.

Generative

KL Divergence

python
nn.KLDivLoss()

When to use: VAEs, distribution matching.

Wasserstein Distance

python
# Earth mover's distance

When to use: WGANs.

Perceptual Loss

python
# VGG feature space distance

When to use: Style transfer, super-resolution.

Adversarial Loss

python
# GAN discriminator/generator losses

Specialized

CTC Loss (Connectionist Temporal Classification)

python
nn.CTCLoss()

When to use: Speech recognition, OCR (alignment-free).

Dice Loss

python
# 2 * |A ∩ B| / (|A| + |B|)

When to use: Medical image segmentation.

IoU Loss

python
# Intersection over Union

When to use: Object detection, segmentation.


Regularization Techniques

Dropout

python
nn.Dropout(p=0.5)

When to use: Prevent overfitting in FC layers.

Spatial Dropout

python
nn.Dropout2d(p=0.5)

When to use: CNNs (drops entire feature maps).

Alpha Dropout

python
nn.AlphaDropout(p=0.5)

When to use: SELU networks.

DropConnect

What it does: Drops weights instead of activations. When to use: Alternative to Dropout.

Stochastic Depth

What it does: Randomly drops residual blocks during training.

python
# Drop entire layers with probability

When to use: Very deep networks (ResNet). Papers: "Deep Networks with Stochastic Depth" (Huang et al., 2016).

DropBlock

What it does: Drops contiguous regions in feature maps.

python
# Better than Dropout for CNNs

Papers: "DropBlock" (Ghiasi et al., 2018).

Mixup

What it does: Mixes training samples and labels.

python
λ = Beta(α, α)
x_mix = λ*x1 + (1-λ)*x2

When to use: Data augmentation + regularization. Papers: "mixup" (Zhang et al., 2018).

CutMix

What it does: Cuts and pastes image patches.

python
# Combines spatial regions from two images

When to use: Image classification. Papers: "CutMix" (Yun et al., 2019).

Weight Decay (L2 Regularization)

python
optimizer = optim.AdamW(params, weight_decay=0.01)

L1 Regularization

python
# Add |w| to loss

When to use: Sparsity induction.

Gradient Clipping

python
nn.utils.clip_grad_norm_(params, max_norm=1.0)

When to use: Prevent exploding gradients (RNNs).


Graph Neural Networks

GCN (Graph Convolutional Network)

What it does: Aggregates neighbor features via spectral convolution.

python
# H' = σ(D^(-1/2) A D^(-1/2) H W)

When to use: Node classification, graph classification. Papers: "Semi-Supervised Classification with GCNs" (Kipf & Welling, 2017).

GAT (Graph Attention Network)

What it does: Learns attention weights for neighbor aggregation.

python
# α_ij = softmax(LeakyReLU(a^T [Wh_i || Wh_j]))

When to use: When neighbor importance varies. Papers: "Graph Attention Networks" (Veličković et al., 2018).

GraphSAGE

What it does: Samples and aggregates neighbors (inductive).

python
# h_v = σ(W · CONCAT(h_v, AGG({h_u})))

When to use: Large graphs, inductive learning. Papers: "Inductive Representation Learning on Large Graphs" (Hamilton et al., 2017).

MPNN (Message Passing Neural Network)

What it does: General framework for graph convolutions.

python
# m_v = Σ M(h_v, h_u, e_uv)
# h_v' = U(h_v, m_v)

When to use: Molecular property prediction. Papers: "Neural Message Passing for Quantum Chemistry" (Gilmer et al., 2017).

GIN (Graph Isomorphism Network)

What it does: Maximally expressive GNN (WL test equivalent).

Core Innovation: GIN is provably the most powerful message-passing GNN architecture, achieving the same discriminative power as the Weisfeiler-Lehman (WL) graph isomorphism test. It can distinguish any pair of non-isomorphic graphs that the WL test can distinguish.

Mathematical Foundation:

python
# Standard GIN update rule
h_v^(k+1) = MLP^(k)((1 + ε^(k)) · h_v^(k) + Σ_{u∈N(v)} h_u^(k))

# Where:
# - h_v^(k) is the feature vector of node v at layer k
# - ε is a learnable parameter (or fixed scalar)
# - MLP is a multi-layer perceptron
# - N(v) is the neighborhood of node v

Why it's maximally expressive:

  • The (1+ε) term ensures the central node's own features are never "lost" in aggregation
  • The MLP can learn to map distinct multisets to distinct representations (injective function)
  • This combination allows GIN to capture the full WL test's expressive power

Implementation Details:

PyTorch Geometric:

python
from torch_geometric.nn import GINConv
from torch.nn import Sequential, Linear, ReLU

# Define MLP for each GIN layer
nn1 = Sequential(Linear(in_features, 64), ReLU(), Linear(64, 64))
conv1 = GINConv(nn1, train_eps=True)  # Learnable ε

# Forward pass
x = conv1(x, edge_index)

Full GIN Model:

python
import torch
import torch.nn.functional as F
from torch_geometric.nn import GINConv, global_add_pool

class GIN(torch.nn.Module):
    def __init__(self, num_features, num_classes, hidden_dim=64, num_layers=5):
        super(GIN, self).__init__()
        
        self.convs = torch.nn.ModuleList()
        self.batch_norms = torch.nn.ModuleList()
        
        # First layer
        nn = Sequential(
            Linear(num_features, hidden_dim),
            ReLU(),
            Linear(hidden_dim, hidden_dim)
        )
        self.convs.append(GINConv(nn, train_eps=True))
        self.batch_norms.append(torch.nn.BatchNorm1d(hidden_dim))
        
        # Hidden layers
        for _ in range(num_layers - 1):
            nn = Sequential(
                Linear(hidden_dim, hidden_dim),
                ReLU(),
                Linear(hidden_dim, hidden_dim)
            )
            self.convs.append(GINConv(nn, train_eps=True))
            self.batch_norms.append(torch.nn.BatchNorm1d(hidden_dim))
        
        # Graph-level readout
        self.fc = Linear(hidden_dim, num_classes)
    
    def forward(self, x, edge_index, batch):
        # Node embedding
        for conv, bn in zip(self.convs, self.batch_norms):
            x = conv(x, edge_index)
            x = bn(x)
            x = F.relu(x)
        
        # Graph-level pooling (sum aggregation)
        x = global_add_pool(x, batch)
        
        # Classification
        return self.fc(x)

Key Design Choices:

  1. Learnable vs. Fixed ε:

    • train_eps=True: ε is a learnable parameter (recommended)
    • train_eps=False: ε is fixed to 0 (simpler, often sufficient)
  2. MLP Architecture:

    • Original paper uses 2-layer MLPs with ReLU
    • Deeper MLPs can improve expressiveness but risk overfitting
    • BatchNorm between GIN layers stabilizes training
  3. Graph-level Readout:

    • Sum pooling (most common): global_add_pool
    • Mean pooling: global_mean_pool
    • Max pooling: global_max_pool
    • Sum pooling is theoretically preferred (maintains injectivity)

When to use GIN:

  • Graph classification (molecular property prediction, social network analysis)
  • When you need maximum discriminative power from message passing
  • Small to medium graphs where WL-test expressiveness matters
  • Molecular fingerprinting (chemistry, drug discovery)

Advantages:

  • ✅ Provably most expressive MPNN
  • ✅ Simple and efficient
  • ✅ Strong empirical performance on graph classification
  • ✅ Theoretical guarantees (WL-test equivalence)

Limitations:

  • ❌ Cannot distinguish graphs beyond WL-test (e.g., some regular graphs)
  • ❌ Requires careful MLP design for injectivity
  • ❌ May overfit on small datasets (high capacity)
  • ❌ Sum aggregation can cause numerical instability with very large graphs

Comparison to other GNNs:

ArchitectureExpressivenessUse Case
GCNLimited (mean aggregation loses info)Node classification, fast inference
GATLimited (attention doesn't guarantee injectivity)Heterogeneous graphs, interpretability
GraphSAGELimited (sampling loses structure)Large-scale inductive learning
GINMaximal (WL-equivalent)Graph classification, molecular tasks

Optimizer Pairing:

  • Adam/AdamW (lr=0.001-0.01): Standard choice
  • SGD + Momentum (lr=0.01): For very large graphs
  • Use weight decay (1e-4 to 1e-5) to prevent overfitting
  • ReduceLROnPlateau scheduler works well

Typical Hyperparameters:

python
hidden_dim = 64        # Feature dimension
num_layers = 5         # 3-5 layers typical (deeper = more expressive)
batch_size = 32        # Graph-level batching
dropout = 0.5          # After each GIN layer
lr = 0.01              # Learning rate
weight_decay = 1e-5    # L2 regularization

Applications:

  • Drug discovery: Predicting molecular properties (solubility, toxicity)
  • Social networks: Community detection, influence prediction
  • Bioinformatics: Protein function prediction, gene regulatory networks
  • Chemistry: Reaction prediction, retrosynthesis
  • Recommendation systems: User-item graph classification

Advanced Variants:

  • GIN-ε: Fixed ε = 0 (simpler, often works well)
  • GIN with virtual nodes: Adds global node for long-range interactions
  • GIN + Edge features: Extends to edge-attributed graphs
  • Higher-order GIN: Uses subgraph patterns (beyond WL-test)

Papers:

  • "How Powerful are Graph Neural Networks?" (Xu et al., ICLR 2019)
    • Proves GIN achieves WL-test expressiveness
    • Shows other GNNs are strictly less expressive
  • "Weisfeiler and Leman Go Neural" (Morris et al., AAAI 2019)
    • Extends analysis to higher-order WL tests

Practical Tips:

  1. Start with 3-5 layers (more layers = larger receptive field)
  2. Use BatchNorm for stable training
  3. Sum pooling > mean/max for graph classification
  4. Add dropout (0.5) to prevent overfitting
  5. For node classification, skip global pooling
  6. For very large graphs, consider GraphSAGE instead (sampling-based)

Example Use Case (Molecular Property Prediction):

python
# Dataset: MUTAG (molecular graphs)
from torch_geometric.datasets import TUDataset

dataset = TUDataset(root='/tmp/MUTAG', name='MUTAG')
# 188 molecules, binary classification (mutagenic or not)

model = GIN(num_features=dataset.num_features, 
            num_classes=dataset.num_classes,
            hidden_dim=64, 
            num_layers=5)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=1e-5)
criterion = torch.nn.CrossEntropyLoss()

# Training loop
for epoch in range(200):
    for data in train_loader:
        optimizer.zero_grad()
        out = model(data.x, data.edge_index, data.batch)
        loss = criterion(out, data.y)
        loss.backward()
        optimizer.step()

Hypergraph Neural Networks

What it does: Extends GNNs to hyperedges (connecting >2 nodes).

python
# Models many-to-many relationships

When to use: Collaborative filtering, chemical reactions.

Temporal Graph Networks

What it does: Handles dynamic graphs with temporal edges.

python
# TGN, TGAT architectures

When to use: Social networks, traffic prediction.


Neuromorphic & Spiking Networks

Leaky Integrate-and-Fire (LIF)

What it does: Neuron integrates input, fires spike when threshold reached.

python
# τ dV/dt = -(V - V_rest) + I(t)
# if V > V_th: spike, V = V_reset

When to use: Energy-efficient neuromorphic hardware.

Adaptive LIF (ALIF)

What it does: LIF with adaptive threshold.

python
# Threshold increases after spikes

When to use: Modeling biological adaptation.

Izhikevich Model

What it does: Efficient spiking neuron with rich dynamics.

python
# dv/dt = 0.04v² + 5v + 140 - u + I
# du/dt = a(bv - u)

When to use: Realistic spiking simulations. Papers: "Simple Model of Spiking Neurons" (Izhikevich, 2003).

STDP (Spike-Timing-Dependent Plasticity)

What it does: Hebbian learning based on spike timing.

python
# Δw ∝ exp(-Δt/τ)  # if pre before post

When to use: Unsupervised spiking learning.

Rank-Order Coding

What it does: Encodes information in spike timing order. When to use: Fast neuromorphic inference.

Surrogate Gradient Methods

What it does: Approximates non-differentiable spike function.

python
# Use sigmoid/arctan as surrogate for backprop

When to use: Training SNNs with backprop.


Continuous-Time & Liquid Networks

Liquid State Machines (LSM)

What it does: Reservoir of spiking neurons projects to readout.

python
# Random recurrent spiking network + linear readout

When to use: Temporal pattern recognition.

Neural ODEs

What it does: Treats network as continuous ODE.

python
# dh/dt = f(h(t), t, θ)
# Solve with ODE solver

When to use: Irregular time-series, continuous depth. Papers: "Neural Ordinary Differential Equations" (Chen et al., 2018).

Continuous-Time RNNs (CT-RNN)

What it does: RNN with continuous-time dynamics.

python
# τ dh/dt = -h + f(Wx + Uh)

When to use: Modeling physical systems.

CfC (Closed-form Continuous-time)

What it does: Efficient liquid networks with closed-form solutions.

python
# from ncps.torch import CfC

When to use: Robotics, time-series (SOTA liquid nets). Papers: "Closed-form Continuous-time Neural Networks" (Hasani et al., 2022).

Liquid Time-Constant Networks (LTC)

What it does: Neurons with learnable time constants.

python
# Adaptive temporal dynamics

When to use: Autonomous driving, robotics.


Generative Models

VAE (Variational Autoencoder)

What it does: Learns latent distribution via variational inference.

python
# Encoder: q(z|x), Decoder: p(x|z)
# Loss: reconstruction + KL(q||p)

When to use: Generative modeling, disentanglement.

GAN (Generative Adversarial Network)

What it does: Generator vs. Discriminator adversarial training.

python
# G: noise → fake data
# D: real/fake classifier

When to use: Image generation, data augmentation. Papers: "Generative Adversarial Networks" (Goodfellow et al., 2014).

WGAN (Wasserstein GAN)

What it does: Uses Wasserstein distance for stable training.

python
# Critic instead of discriminator

Papers: "Wasserstein GAN" (Arjovsky et al., 2017).

StyleGAN

What it does: Style-based generator with adaptive instance norm.

python
# Latent codes control style at each resolution

When to use: High-quality face generation. Papers: "A Style-Based Generator Architecture" (Karras et al., 2019).

Diffusion Models (DDPM, DDIM)

What it does: Iterative denoising process.

python
# Forward: add noise
# Reverse: denoise with learned model

When to use: SOTA image/audio generation. Papers: "Denoising Diffusion Probabilistic Models" (Ho et al., 2020).

Score-Based Models

What it does: Learns score function (gradient of log-density).

python
# ∇_x log p(x)

When to use: Alternative to diffusion.

Flow-Based Models (Normalizing Flows)

What it does: Invertible transformations for exact likelihood.

python
# RealNVP, Glow

When to use: Exact density estimation.

VQ-VAE (Vector Quantized VAE)

What it does: Discrete latent space via codebook.

python
# Quantize continuous latents

When to use: Image/audio generation (DALL-E). Papers: "Neural Discrete Representation Learning" (van den Oord et al., 2017).

Autoregressive Models

What it does: Models p(x) = ∏ p(x_i | x_<i)._

python
# PixelCNN, WaveNet, GPT

When to use: Sequential generation.


Meta-Learning & Hypernetworks

Hypernetworks

What it does: Network generates weights for another network.

python
# θ_task = HyperNet(task_embedding)

When to use: Multi-task learning, fast adaptation. Papers: "HyperNetworks" (Ha et al., 2016).

MAML (Model-Agnostic Meta-Learning)

What it does: Learns initialization for fast fine-tuning.

python
# θ' = θ - α∇L_task(θ)
# θ = θ - β∇Σ L_task(θ')

When to use: Few-shot learning. Papers: "Model-Agnostic Meta-Learning" (Finn et al., 2017).

Prototypical Networks

What it does: Classifies via distance to class prototypes.

python
# prototype_c = mean(embeddings_c)

When to use: Few-shot classification.

Matching Networks

What it does: Attention-based few-shot learning. Papers: "Matching Networks for One Shot Learning" (Vinyals et al., 2016).

Neural Architecture Search (NAS)

What it does: Automates architecture design.

python
# DARTS, ENAS, AutoML

When to use: Finding optimal architectures.

Dynamic Networks

What it does: Adapts architecture/weights at inference.

python
# Conditional computation, early exiting

When to use: Efficient inference.


Quantum & Optical Networks

Quantum Neural Networks (QNN)

What it does: Parameterized quantum circuits.

python
# from pennylane import qnode
# Quantum gates as layers

When to use: Quantum advantage tasks (theoretical). Papers: "Quantum Machine Learning" (Biamonte et al., 2017).

Variational Quantum Eigensolver (VQE)

What it does: Hybrid quantum-classical optimization. When to use: Chemistry, materials science.

Optical Neural Networks

What it does: Photonic implementations of matrix multiplication.

python
# Mach-Zehnder interferometers

When to use: Ultra-low-power, high-speed inference. Papers: "Deep Learning with Coherent Nanophotonic Circuits" (Shen et al., 2017).

Diffractive Deep Neural Networks

What it does: Optical layers via diffraction. When to use: All-optical image processing.


Specialized Architectures

Transformers (Complete Architecture)

What it does: Attention-based sequence-to-sequence architecture that revolutionized NLP and beyond. Uses self-attention instead of recurrence. Architecture components:

The math:

Encoder: x → MultiHeadAttn(x,x,x) → FFN
Decoder: y → MaskedMultiHeadAttn(y,y,y) → CrossAttn(y,enc) → FFN
python
class TransformerBlock(nn.Module):
    def __init__(self, d_model=512, n_heads=8, d_ff=2048, dropout=0.1):
        super().__init__()
        self.attention = nn.MultiheadAttention(d_model, n_heads, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model)
        )
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x):
        # Self-attention with residual
        attn_out, _ = self.attention(x, x, x)
        x = self.norm1(x + self.dropout(attn_out))
        # FFN with residual
        ffn_out = self.ffn(x)
        return self.norm2(x + self.dropout(ffn_out))

When to use: Sequence modeling, translation, text generation, any task where attention over full sequence helps. Use cases: Machine translation, text generation, question answering, document understanding, protein folding. Hyperparameters:

  • d_model: Model dimension (512 for base, 768/1024 for large)
  • n_heads: Number of attention heads (8 for base, 16+ for large)
  • n_layers: Number of transformer blocks (6-12 typical, 24-96 for LLMs)
  • d_ff: FFN hidden dimension (4×d_model typical, 2048-4096)
  • dropout: Dropout rate (0.1 typical)
  • max_seq_len: Maximum sequence length (512-2048 typical, 4K-128K for modern LLMs)

Variants:

  • Encoder-only: BERT, RoBERTa (bidirectional understanding)
  • Decoder-only: GPT series, LLaMA (autoregressive generation)
  • Encoder-Decoder: T5, BART (sequence-to-sequence tasks)

Models using it: BERT, GPT-2/3/4, T5, BART, LLaMA, PaLM, Falcon, all modern LLMs. Papers: "Attention Is All You Need" (Vaswani et al., 2017).

BERT (Bidirectional Encoder Representations from Transformers)

What it does: Encoder-only Transformer pre-trained with masked language modeling. Bidirectional context understanding. Architecture: 12-24 Transformer encoder blocks with bidirectional self-attention.

python
# BERT architecture
from transformers import BertModel
model = BertModel.from_pretrained('bert-base-uncased')
# 12 layers, 768 hidden, 12 attention heads

Pre-training tasks:

  1. Masked Language Model (MLM): Predict masked tokens ([MASK])
  2. Next Sentence Prediction (NSP): Predict if sentence B follows A

When to use: Text classification, NER, question answering, any understanding task (not generation). Use cases: Sentiment analysis, named entity recognition, question answering, document classification. Hyperparameters:

  • hidden_size: 768 (base), 1024 (large)
  • num_layers: 12 (base), 24 (large)
  • num_attention_heads: 12 (base), 16 (large)
  • max_position_embeddings: 512
  • vocab_size: 30,522 (English)

Models in family: RoBERTa (better pre-training), ALBERT (parameter sharing), DeBERTa (disentangled attention), DistilBERT (distilled). Papers: "BERT: Pre-training of Deep Bidirectional Transformers" (Devlin et al., 2019).

GPT (Generative Pre-trained Transformer)

What it does: Decoder-only Transformer for autoregressive text generation. Uses causal attention. Architecture: Stack of Transformer decoder blocks with causal masking.

python
# GPT-style model
class GPT(nn.Module):
    def __init__(self, vocab_size=50257, n_layer=12, n_head=12, n_embd=768):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, n_embd)
        self.pos_emb = nn.Embedding(1024, n_embd)  # Max context
        self.blocks = nn.ModuleList([TransformerBlock(n_embd, n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        self.head = nn.Linear(n_embd, vocab_size, bias=False)

Pre-training: Causal language modeling (predict next token given previous tokens). When to use: Text generation, code completion, dialogue, any autoregressive generation task. Use cases: Creative writing, code generation, chatbots, text completion, instruction following. Hyperparameters (GPT-3 175B):

  • n_layer: 96
  • n_head: 96
  • n_embd: 12,288
  • context_length: 2,048 (GPT-3), 8K-32K (GPT-4)
  • vocab_size: ~50K BPE tokens

Evolution:

  • GPT-1: 117M params, 12 layers
  • GPT-2: 1.5B params, 48 layers, open-sourced
  • GPT-3: 175B params, 96 layers, few-shot learning
  • GPT-4: Multi-modal, larger context

Models in family: GPT-2, GPT-3, GPT-4, GPT-J, GPT-NeoX, many open variants. Papers: "Language Models are Unsupervised Multitask Learners" (Radford et al., 2019).

T5 (Text-to-Text Transfer Transformer)

What it does: Encoder-decoder Transformer that treats every NLP task as text-to-text. Architecture: Full encoder-decoder Transformer with relative position bias.

python
from transformers import T5Model
model = T5Model.from_pretrained('t5-base')
# Encoder: 12 layers, Decoder: 12 layers, 768 hidden

Key innovation: Unified text-to-text framework - all tasks formatted as:

  • Translation: translate English to German: HelloHallo
  • Summarization: summarize: <document><summary>
  • QA: question: <q> context: <c><answer>

When to use: Seq2seq tasks, summarization, translation, tasks needing encoder understanding + decoder generation. Use cases: Summarization, translation, question answering, text rewriting, paraphrasing. Hyperparameters:

  • d_model: 768 (base), 1024 (large), 4096 (XXL 11B)
  • num_layers: 12 (base), 24 (large)
  • d_ff: 3072 (base), 16384 (XXL)
  • num_heads: 12 (base), 16-32 (large)
  • relative_position_buckets: 32

Models in family: T5-base, T5-large, T5-3B, T5-11B, Flan-T5 (instruction-tuned), mT5 (multilingual). Papers: "Exploring the Limits of Transfer Learning" (Raffel et al., 2020).

LLaMA (Large Language Model Meta AI)

What it does: Efficient open-source decoder-only Transformer with RoPE, RMSNorm, and SwiGLU. Architecture: GPT-style with optimizations for efficiency.

python
# LLaMA improvements over GPT
# 1. RoPE instead of learned positional embeddings
# 2. RMSNorm instead of LayerNorm (faster)
# 3. SwiGLU activations instead of ReLU
# 4. Grouped-Query Attention (GQA) in LLaMA 2

When to use: Open-source LLM base for research, fine-tuning, commercial use (LLaMA 2). Use cases: Chatbots, instruction following, code generation, research on LLMs. Hyperparameters (LLaMA 2 70B):

  • dim: 8,192
  • n_layers: 80
  • n_heads: 64
  • n_kv_heads: 8 (GQA)
  • vocab_size: 32,000
  • max_seq_len: 4,096

Models in family: LLaMA 7B/13B/33B/65B, LLaMA 2 7B/13B/70B, Code Llama, Llama 3. Papers: "LLaMA: Open and Efficient Foundation Language Models" (Touvron et al., 2023).

Vision Transformer (ViT)

What it does: Applies Transformers to image patches.

python
# Patch embedding + positional encoding

Papers: "An Image is Worth 16x16 Words" (Dosovitskiy et al., 2021).

Swin Transformer

What it does: Hierarchical ViT with shifted windows. When to use: Dense prediction tasks. Papers: "Swin Transformer" (Liu et al., 2021).

ResNet (Residual Networks)

What it does: Skip connections enable very deep networks.

python
# y = F(x) + x

Papers: "Deep Residual Learning" (He et al., 2016).

DenseNet

What it does: Each layer connects to all previous layers.

python
# Dense connectivity

Papers: "Densely Connected CNNs" (Huang et al., 2017).

Inception / GoogLeNet

What it does: Multi-scale convolutions in parallel.

python
# 1x1, 3x3, 5x5 convs concatenated

U-Net

What it does: Encoder-decoder with skip connections.

python
# Symmetric architecture

When to use: Medical image segmentation. Papers: "U-Net" (Ronneberger et al., 2015).

EfficientNet

What it does: Compound scaling (depth, width, resolution).

python
# Optimized scaling coefficients

Papers: "EfficientNet" (Tan & Le, 2019).

MobileNet

What it does: Depthwise separable convs for mobile. Papers: "MobileNets" (Howard et al., 2017).

YOLO (You Only Look Once)

What it does: Single-shot object detection.

python
# Grid-based detection

Papers: "You Only Look Once" (Redmon et al., 2016).

Mask R-CNN

What it does: Instance segmentation (Faster R-CNN + masks). Papers: "Mask R-CNN" (He et al., 2017).

NeRF (Neural Radiance Fields)

What it does: 3D scene representation via MLPs.

python
# (x,y,z,θ,φ) → (RGB, density)

When to use: Novel view synthesis. Papers: "NeRF" (Mildenhall et al., 2020).

Perceiver / Perceiver IO

What it does: Handles arbitrary modalities via cross-attention.

python
# Latent bottleneck + cross-attention

Papers: "Perceiver" (Jaegle et al., 2021).

Retentive Networks (RetNet)

What it does: Recurrence + parallelism for efficient LLMs. Papers: "Retentive Network" (Sun et al., 2023).


Hardware-Specific Optimizations

Mixed Precision Training (FP16/BF16)

python
from torch.cuda.amp import autocast, GradScaler

When to use: Faster training on modern GPUs.

Quantization (INT8, INT4)

python
torch.quantization.quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)

When to use: Inference acceleration, edge deployment.

Pruning

What it does: Removes low-magnitude weights.

python
torch.nn.utils.prune.l1_unstructured(layer, 'weight', amount=0.3)

When to use: Model compression.

Knowledge Distillation

What it does: Train small model to mimic large model.

python
# Loss = α*CE(student, labels) + (1-α)*KL(student, teacher)

When to use: Model compression.

Gradient Checkpointing

python
torch.utils.checkpoint.checkpoint(module, input)

When to use: Trade compute for memory (large models).

Tensor Parallelism

What it does: Splits tensors across devices. When to use: Training massive models (GPT-3).

Pipeline Parallelism

What it does: Splits model layers across devices. When to use: Very deep models.

ZeRO (Zero Redundancy Optimizer)

What it does: Shards optimizer states across GPUs. When to use: Distributed training (DeepSpeed).


Advanced Spatial & Structural Architectures

Graph Neural Networks (GNN)

What it does: Processes data represented as nodes and edges rather than grid-based pixels.

python
import torch_geometric.nn as gnn
# x = gnn.GCNConv(in_channels, out_channels)(x, edge_index)

When to use: Relationship modeling, social networks, molecular discovery. Papers: "Semi-Supervised Classification with Graph Convolutional Networks" (Kipf & Welling, 2017).

Neural Fabrics

What it does: A massive super-network that learns to route data through optimal paths, evolving its architecture.

python
# PathNet: agents discover pathways through the fabric
# output = fabric.route(input, task_id)

When to use: Multi-task learning, AGI research, evolving topologies. Papers: "PathNet: Evolution Channels Gradient Descent in Super Neural Networks" (Fernando et al., 2017).

3D Space Networks

What it does: Processes point clouds or volumetric grids (X, Y, Z) directly.

python
# PointNet: processes raw point sets
# model(points)  # shape: (batch, 3, num_points)

When to use: Lidar processing, robotics, 3D object classification. Papers: "PointNet" (Qi et al., 2017).

4D Spatiotemporal Networks

What it does: Analyzes 3D space plus time, understanding how volumetric objects move/deform.

python
# 3D Video Transformers
# x = TimeSformer(video_tensor)

When to use: Dynamic scene understanding, advanced physics prediction.


Optimal Sensory Configurations (True Senses)

Neuro-Symbolic AI (True Logic)

What it does: Integrates neural networks with symbolic logic rules for verifiable reasoning.

python
# Logic Tensor Networks (LTN)
# SatAggr(Forall(x, implies(Cat(x), Animal(x))))

When to use: "True Logic" systems requiring mathematical verification and explainability.

Semantic 3D Vision (True Sight)

What it does: Combines Vision Transformers (ViT) for context with Gaussian Splatting for spatial fidelity.

python
# context = ViT(image)
# scene = GaussianSplatter(point_cloud, features=context)

When to use: "True Sight", photorealistic scene reconstruction with semantic understanding.

Physics-Informed Audio (True Hearing)

What it does: Self-supervised models that learn sound physics features before linguistic mapping.

python
# Wav2Vec 2.0 / HuBERT
# raw_audio -> latent_physics_representation -> quantization

When to use: "True Hearing", biologically plausible sound analysis.


ACE™ Framework Components (Hybrid NN Specialty)

Consciousness-Aware Tokenization (CAFVE)

What it does: Dual-cortex perception (visual + auditory SOMs).

python
# 15×15 SOMs with cross-modal fusion

When to use: Organic perception without token bottlenecks.

Hebbian Co-Activation

What it does: Strengthens connections based on firing synchronicity.

python
# Δw_ij ∝ a_i * a_j

When to use: Unsupervised feature binding.

Meta-Learning Hierarchy (Differential Plasticity)

What it does: Identity (low LR) → Attention (med) → Prediction (high).

python
# Prevents catastrophic forgetting

When to use: Maintaining stable identity while learning.

Dream-Based Consolidation

What it does: Circadian-triggered episodic replay.

python
# Prioritized by prediction error + novelty

When to use: Long-term memory integration.

Organic Parameter Growth

What it does: Network expands from 500M → 600M based on learning demand.

python
# Add neurons/connections dynamically

When to use: Adaptive capacity scaling.

Emotional State Space (1024D)

What it does: 34,000+ emotional nuances via high-dimensional embeddings.

python
# Trust-modulated emotional contagion

When to use: Genuine empathy and social bonding.

Theory of Mind (ToM) Module

What it does: Models human mental states.

python
# Predicts beliefs, intentions, emotions

When to use: Social AI, companion systems.


Theoretical & Emerging Concepts

Self-Assembling Topologies

What it does: Networks grow connections organically during training. When to use: Neuromorphic research.

Neomorphic Graph Fusion

What it does: Context-aware graph restructuring at inference. When to use: Dynamic relational reasoning.

Hyperdynamic Parameterization

What it does: Weights change based on input context (extreme hypernetworks). When to use: Highly adaptive systems.

Complex-Valued Neural Networks

What it does: Uses complex numbers for weights/activations.

python
# For phase-aware signal processing

When to use: Radar, MRI, quantum mechanics.

Quaternion Neural Networks

What it does: 4D hypercomplex numbers for 3D rotations. When to use: Robotics, 3D vision.

Kolmogorov-Arnold Networks (KAN)

What it does: Learnable activation functions on edges (not nodes).

python
# Replaces fixed activations with splines

When to use: Function approximation with fewer parameters. Papers: "KAN: Kolmogorov-Arnold Networks" (Liu et al., 2024).

Hopfield Networks (Modern)

What it does: Energy-based associative memory.

python
# Continuous Hopfield networks as attention

When to use: Memory retrieval, pattern completion. Papers: "Hopfield Networks is All You Need" (Ramsauer et al., 2021).

Boltzmann Machines

What it does: Stochastic recurrent network with energy function.

python
# RBM (Restricted Boltzmann Machine)

When to use: Unsupervised learning (historical).

Echo State Networks (ESN)

What it does: Reservoir computing with fixed random recurrent weights.

python
# Only train readout layer

When to use: Fast training for time-series.

Extreme Learning Machines (ELM)

What it does: Random hidden layer, train only output. When to use: Ultra-fast training.

Self-Organizing Maps (SOM)

What it does: Unsupervised clustering via competitive learning.

python
# Kohonen maps

When to use: Dimensionality reduction, visualization.

Growing Neural Gas

What it does: Incrementally adds neurons to topology. When to use: Adaptive clustering.

Liquid Neural Networks (General)

What it does: Time-continuous, causal, compact. When to use: Robotics, autonomous systems.

Spiking Attention Mechanisms

What it does: Sparse temporal attention in SNNs. When to use: Neuromorphic transformers.

Memristive Networks

What it does: Uses memristors (resistive memory) as synapses. When to use: Analog neuromorphic hardware.


Block Types & Architectural Patterns

Residual Block (ResNet)

What it does: Adds skip connections to enable very deep networks. The math: y = F(x) + x (identity mapping)

python
class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU()
    
    def forward(self, x):
        residual = x
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += residual  # Skip connection
        return self.relu(out)

When to use: Any deep network (50+ layers). Prevents vanishing gradients. Models using it: ResNet, ResNeXt, Wide ResNet, EfficientNet. Papers: "Deep Residual Learning" (He et al., 2016).

Bottleneck Block

What it does: 1×1 conv reduces dimensions, 3×3 conv processes, 1×1 conv expands back.

python
class BottleneckBlock(nn.Module):
    def __init__(self, in_channels, bottleneck_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, bottleneck_channels, 1)  # Reduce
        self.conv2 = nn.Conv2d(bottleneck_channels, bottleneck_channels, 3, padding=1)  # Process
        self.conv3 = nn.Conv2d(bottleneck_channels, out_channels, 1)  # Expand
        self.shortcut = nn.Conv2d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()
    
    def forward(self, x):
        return self.conv3(self.conv2(self.conv1(x))) + self.shortcut(x)

When to use: Reducing computation in deep networks while maintaining expressiveness. Models using it: ResNet-50/101/152, MobileNetV2 (inverted residual).

Inverted Residual Block (MobileNetV2)

What it does: Expands → depthwise conv → compresses (opposite of bottleneck).

python
class InvertedResidual(nn.Module):
    def __init__(self, in_channels, out_channels, expansion=6):
        super().__init__()
        hidden_dim = in_channels * expansion
        self.expand = nn.Conv2d(in_channels, hidden_dim, 1)
        self.depthwise = nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1, groups=hidden_dim)
        self.project = nn.Conv2d(hidden_dim, out_channels, 1)
        self.use_residual = (in_channels == out_channels)
    
    def forward(self, x):
        out = self.project(self.depthwise(self.expand(x)))
        return x + out if self.use_residual else out

When to use: Mobile/edge deployment, efficiency-first architectures. Models using it: MobileNetV2, MobileNetV3, EfficientNet.

Dense Block (DenseNet)

What it does: Each layer connects to all subsequent layers (feature reuse).

python
class DenseBlock(nn.Module):
    def __init__(self, num_layers, in_channels, growth_rate):
        super().__init__()
        self.layers = nn.ModuleList()
        for i in range(num_layers):
            self.layers.append(self._make_layer(in_channels + i * growth_rate, growth_rate))
    
    def _make_layer(self, in_channels, growth_rate):
        return nn.Sequential(
            nn.BatchNorm2d(in_channels),
            nn.ReLU(),
            nn.Conv2d(in_channels, growth_rate, 3, padding=1)
        )
    
    def forward(self, x):
        features = [x]
        for layer in self.layers:
            new_features = layer(torch.cat(features, 1))
            features.append(new_features)
        return torch.cat(features, 1)

When to use: Feature propagation, parameter efficiency. Models using it: DenseNet-121/169/201. Papers: "Densely Connected Convolutional Networks" (Huang et al., 2017).

Squeeze-and-Excitation (SE) Block

What it does: Recalibrates channel-wise feature responses via attention.

python
class SEBlock(nn.Module):
    def __init__(self, channels, reduction=16):
        super().__init__()
        self.squeeze = nn.AdaptiveAvgPool2d(1)
        self.excitation = nn.Sequential(
            nn.Linear(channels, channels // reduction),
            nn.ReLU(),
            nn.Linear(channels // reduction, channels),
            nn.Sigmoid()
        )
    
    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.squeeze(x).view(b, c)
        y = self.excitation(y).view(b, c, 1, 1)
        return x * y.expand_as(x)

When to use: Improving any CNN by adding channel attention. Models using it: SENet, EfficientNet, ResNeSt. Papers: "Squeeze-and-Excitation Networks" (Hu et al., 2018).

Transformer Block

What it does: Multi-head attention + feedforward + layer norm.

python
class TransformerBlock(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.attention = nn.MultiheadAttention(d_model, num_heads, dropout=dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model)
        )
    
    def forward(self, x):
        # Attention with residual
        attn_out, _ = self.attention(x, x, x)
        x = self.norm1(x + attn_out)
        # FFN with residual
        ffn_out = self.ffn(x)
        return self.norm2(x + ffn_out)

When to use: Sequence modeling, NLP, vision transformers. Models using it: BERT, GPT, ViT, DALL-E, all modern LLMs.

Fire Module (SqueezeNet)

What it does: Squeeze layer (1×1) followed by expand layer (1×1 + 3×3).

python
class FireModule(nn.Module):
    def __init__(self, in_channels, squeeze_channels, expand_channels):
        super().__init__()
        self.squeeze = nn.Conv2d(in_channels, squeeze_channels, 1)
        self.expand_1x1 = nn.Conv2d(squeeze_channels, expand_channels, 1)
        self.expand_3x3 = nn.Conv2d(squeeze_channels, expand_channels, 3, padding=1)
    
    def forward(self, x):
        x = F.relu(self.squeeze(x))
        return torch.cat([
            F.relu(self.expand_1x1(x)),
            F.relu(self.expand_3x3(x))
        ], dim=1)

When to use: Extremely lightweight CNNs for embedded systems. Models using it: SqueezeNet (AlexNet-level accuracy with 50× fewer parameters).

Inception Module (GoogLeNet)

What it does: Parallel convolutions with different kernel sizes.

python
class InceptionModule(nn.Module):
    def __init__(self, in_channels, out_1x1, out_3x3_reduce, out_3x3, out_5x5_reduce, out_5x5, out_pool):
        super().__init__()
        self.branch1 = nn.Conv2d(in_channels, out_1x1, 1)
        self.branch2 = nn.Sequential(
            nn.Conv2d(in_channels, out_3x3_reduce, 1),
            nn.Conv2d(out_3x3_reduce, out_3x3, 3, padding=1)
        )
        self.branch3 = nn.Sequential(
            nn.Conv2d(in_channels, out_5x5_reduce, 1),
            nn.Conv2d(out_5x5_reduce, out_5x5, 5, padding=2)
        )
        self.branch4 = nn.Sequential(
            nn.MaxPool2d(3, stride=1, padding=1),
            nn.Conv2d(in_channels, out_pool, 1)
        )
    
    def forward(self, x):
        return torch.cat([
            self.branch1(x),
            self.branch2(x),
            self.branch3(x),
            self.branch4(x)
        ], dim=1)

When to use: Multi-scale feature extraction. Models using it: GoogLeNet (Inception-v1), Inception-v3/v4.


Embedding Strategies

Token Embeddings

What it does: Maps discrete tokens (words, subwords) to continuous vectors.

python
vocab_size = 50000
embedding_dim = 512
token_embeddings = nn.Embedding(vocab_size, embedding_dim)

# Usage
input_ids = torch.tensor([1, 42, 100, 5])  # Token IDs
embeddings = token_embeddings(input_ids)  # [4, 512]

When to use: All NLP tasks, text processing. Models using it: Every language model (BERT, GPT, T5).

Positional Embeddings (Learned)

What it does: Adds position information to token embeddings.

python
max_seq_len = 512
pos_embeddings = nn.Embedding(max_seq_len, embedding_dim)

positions = torch.arange(seq_len)
pos_emb = pos_embeddings(positions)
combined = token_emb + pos_emb

When to use: Transformers, when order matters. Models using it: BERT, GPT-2 (learned positional embeddings).

Sinusoidal Positional Encoding

What it does: Deterministic position encoding using sine/cosine functions. The math:

PE(pos, 2i) = sin(pos / 10000^(2i/d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
python
def sinusoidal_encoding(max_len, d_model):
    position = torch.arange(max_len).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
    pe = torch.zeros(max_len, d_model)
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe

When to use: Original Transformer, extrapolation to longer sequences. Models using it: Original Transformer, T5.

Rotary Position Embeddings (RoPE)

What it does: Rotates query/key representations by angle proportional to position.

python
# Conceptual - rotates Q and K by position-dependent angle
def apply_rotary_emb(x, position):
    # Rotation matrix based on position
    # Preserves relative position information in dot product
    return rotate(x, position)

When to use: Better length extrapolation than learned embeddings. Models using it: GPT-NeoX, LLaMA, PaLM. Papers: "RoFormer: Enhanced Transformer with Rotary Position Embedding" (Su et al., 2021).

ALiBi (Attention with Linear Biases)

What it does: Adds position-dependent bias to attention scores (no position embeddings needed). The math:

attention_score(q_i, k_j) = q_i · k_j - m · |i - j|
where m is a head-specific slope
python
# Bias added directly to attention logits
# No extra parameters needed
bias = -slope * torch.abs(positions.unsqueeze(1) - positions.unsqueeze(0))
attention_scores = attention_scores + bias

When to use: Training on short sequences, inference on long sequences. Models using it: BLOOM, MPT. Papers: "Train Short, Test Long: Attention with Linear Biases" (Press et al., 2022).

Character Embeddings

What it does: Builds word representations from character-level CNNs or RNNs.

python
class CharEmbedding(nn.Module):
    def __init__(self, num_chars, char_dim, num_filters, kernel_sizes):
        super().__init__()
        self.char_embed = nn.Embedding(num_chars, char_dim)
        self.convs = nn.ModuleList([
            nn.Conv1d(char_dim, num_filters, k) for k in kernel_sizes
        ])
    
    def forward(self, chars):
        # chars: [batch, seq_len, max_word_len]
        embedded = self.char_embed(chars)  # [batch, seq_len, max_word_len, char_dim]
        # Apply convolutions and pool
        # Returns word-level embeddings from characters

When to use: Handling rare words, morphologically rich languages. Models using it: ELMo, some NER systems.

Byte-Pair Encoding (BPE) Embeddings

What it does: Subword tokenization for efficient vocabulary.

python
# Using HuggingFace tokenizers
from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
text = "unhappiness"
tokens = tokenizer.tokenize(text)  # ['un', 'happiness'] or similar
token_ids = tokenizer.encode(text)

When to use: Modern LLMs, multilingual models. Models using it: GPT-2/3/4, BERT variants, most modern NLP.

Sentence Embeddings

What it does: Represents entire sentences as single vectors.

python
from sentence_transformers import SentenceTransformer

model = SentenceTransformer('all-MiniLM-L6-v2')
sentences = ["This is a sentence", "This is another sentence"]
embeddings = model.encode(sentences)  # [2, 384]

When to use: Semantic search, similarity tasks, clustering. Models using it: Sentence-BERT, Universal Sentence Encoder.

Image Patch Embeddings (ViT)

What it does: Splits image into patches, linearly embeds each patch.

python
class PatchEmbedding(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.num_patches = (img_size // patch_size) ** 2
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
    
    def forward(self, x):
        # x: [batch, 3, 224, 224]
        x = self.proj(x)  # [batch, 768, 14, 14]
        x = x.flatten(2).transpose(1, 2)  # [batch, 196, 768]
        return x

When to use: Vision Transformers, treating images as sequences. Models using it: ViT, CLIP, DALL-E.


Data Flow Directions

Autoregressive (Left-to-Right)

What it does: Predicts next token given previous tokens, causal masking.

python
# Causal attention mask (can only attend to past)
def create_causal_mask(seq_len):
    mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
    return mask  # Upper triangle is masked (future positions)

When to use: Language generation, any sequential generation task. Models using it: GPT series, all decoder-only LLMs.

Bidirectional

What it does: Processes sequence in both directions, sees full context.

python
# Bidirectional LSTM
lstm = nn.LSTM(input_size, hidden_size, bidirectional=True)
# Output has 2*hidden_size (forward + backward)

# Bidirectional attention (no masking in encoder)
attention = nn.MultiheadAttention(embed_dim, num_heads)

When to use: Encoding, understanding (not generation), classification. Models using it: BERT (encoder), BiLSTM for NER/tagging.

Encoder-Decoder

What it does: Encoder processes input bidirectionally, decoder generates autoregressively.

python
class EncoderDecoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = TransformerEncoder()  # Bidirectional
        self.decoder = TransformerDecoder()  # Autoregressive with cross-attention
    
    def forward(self, src, tgt):
        enc_output = self.encoder(src)  # Full context
        dec_output = self.decoder(tgt, enc_output)  # Attends to encoder
        return dec_output

When to use: Translation, summarization, question answering. Models using it: Original Transformer, T5, BART, Seq2Seq models.

Prefix LM (Left-to-Right with Bidirectional Prefix)

What it does: Bidirectional attention on prefix, causal on generation.

python
# Allow full attention on prefix tokens, causal on generation tokens
def create_prefix_mask(prefix_len, total_len):
    mask = torch.zeros(total_len, total_len)
    # Prefix sees all prefix
    mask[:prefix_len, :prefix_len] = 1
    # Generation part sees prefix + past generation
    for i in range(prefix_len, total_len):
        mask[i, :i+1] = 1
    return mask

When to use: Models that both understand and generate. Models using it: PaLM, GLM.

Sliding Window (Local Attention)

What it does: Each position attends to fixed window around it.

python
window_size = 256
# Position i attends to [i-window_size, i+window_size]

When to use: Long sequences where full attention is too expensive. Models using it: Longformer, BigBird.

Recurrent/Sequential

What it does: Processes one timestep at a time, maintains hidden state.

python
hidden = None
for t in range(seq_len):
    output, hidden = rnn_cell(input[t], hidden)

When to use: Streaming data, online learning, true sequential processing. Models using it: RNNs, LSTMs, GRUs.


Theory of Mind & Cognitive Architectures

Theory of Mind is the ability to attribute mental states (beliefs, intents, desires, knowledge) to oneself and others, and to understand that others have different perspectives. In AI, this translates to models that can reason about other agents, predict their behavior, and understand their goals.

Mental State Tracking

What it does: Models maintain representations of agents' beliefs, goals, and knowledge states.

python
class MentalStateTracker(nn.Module):
    def __init__(self, hidden_dim=512):
        super().__init__()
        self.belief_encoder = nn.LSTM(hidden_dim, hidden_dim)
        self.goal_encoder = nn.LSTM(hidden_dim, hidden_dim)
        self.knowledge_encoder = nn.LSTM(hidden_dim, hidden_dim)
    
    def forward(self, observations, agent_id):
        # Track what agent_id knows, believes, wants
        beliefs = self.belief_encoder(observations)
        goals = self.goal_encoder(observations)
        knowledge = self.knowledge_encoder(observations)
        return {
            'beliefs': beliefs,
            'goals': goals,
            'knowledge': knowledge
        }

When to use: Multi-agent systems, social AI, dialogue systems that need to model users. Applications: AI assistants that adapt to user knowledge level, game AI that predicts player strategies.

Perspective Taking

What it does: Model reasons from another agent's viewpoint, not just its own.

python
class PerspectiveTransformer(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.self_view = TransformerBlock(d_model)
        self.other_view = TransformerBlock(d_model)
        self.perspective_fusion = nn.MultiheadAttention(d_model, num_heads=8)
    
    def forward(self, observations, agent_identities):
        self_repr = self.self_view(observations)
        other_repr = self.other_view(observations)
        # Fuse perspectives with attention
        fused, _ = self.perspective_fusion(self_repr, other_repr, other_repr)
        return fused

When to use: Cooperative AI, negotiation, explaining AI decisions to humans. Research areas: False belief tasks, Sally-Anne test for neural networks.

Intent Recognition

What it does: Infers goals and intentions from observed actions.

python
class IntentRecognizer(nn.Module):
    def __init__(self, action_dim, intent_dim):
        super().__init__()
        self.action_encoder = nn.GRU(action_dim, 256)
        self.intent_classifier = nn.Linear(256, intent_dim)
    
    def forward(self, action_sequence):
        # Observe actions, infer intent
        encoded, _ = self.action_encoder(action_sequence)
        intent_logits = self.intent_classifier(encoded[-1])
        return intent_logits

When to use: Human-robot interaction, assistive AI, security (detecting malicious intent). Models using it: Action recognition systems, video understanding models.

Belief-Desire-Intention (BDI) Architectures

What it does: Classic cognitive architecture with explicit beliefs, desires, and intentions.

Architecture:

  • Beliefs: What the agent thinks is true about the world
  • Desires: Goals the agent wants to achieve
  • Intentions: Plans the agent commits to executing
python
class BDIAgent:
    def __init__(self):
        self.beliefs = BeliefBase()  # Current world model
        self.desires = GoalSet()     # Desired states
        self.intentions = Plan()     # Committed actions
    
    def perceive(self, observations):
        self.beliefs.update(observations)
    
    def deliberate(self):
        # Choose which desires to pursue based on beliefs
        self.intentions = self.plan(self.desires, self.beliefs)
    
    def act(self):
        return self.intentions.next_action()

When to use: Autonomous agents, robotics, simulations of rational agents. Applications: Game AI, robot task planning, agent-based modeling.

Social Learning & Imitation

What it does: Learns by observing and imitating other agents.

python
# Imitation learning from demonstrations
class ImitationLearner(nn.Module):
    def __init__(self, obs_dim, action_dim):
        super().__init__()
        self.policy = nn.Sequential(
            nn.Linear(obs_dim, 256),
            nn.ReLU(),
            nn.Linear(256, action_dim)
        )
    
    def forward(self, observations):
        return self.policy(observations)
    
    def learn_from_demonstration(self, expert_obs, expert_actions):
        predicted_actions = self.forward(expert_obs)
        loss = nn.functional.mse_loss(predicted_actions, expert_actions)
        return loss

When to use: Learning from human demonstrations, behavioral cloning, apprenticeship learning. Models using it: GAIL (Generative Adversarial Imitation Learning), Behavioral Cloning for robotics.

Meta-Cognition (Thinking About Thinking)

What it does: Model monitors and regulates its own cognitive processes.

python
class MetaCognitiveController(nn.Module):
    def __init__(self):
        super().__init__()
        self.confidence_estimator = nn.Linear(512, 1)
        self.uncertainty_estimator = nn.Linear(512, 1)
    
    def forward(self, hidden_state):
        confidence = torch.sigmoid(self.confidence_estimator(hidden_state))
        uncertainty = torch.sigmoid(self.uncertainty_estimator(hidden_state))
        
        # Decide whether to answer or defer
        if confidence > 0.8 and uncertainty < 0.2:
            return "answer"
        else:
            return "defer"  # "I don't know"

When to use: Reliable AI systems that know when they don't know, selective prediction. Applications: Medical AI (defer to human doctor when uncertain), exam-taking AI.

Common Ground Modeling

What it does: Tracks shared knowledge between agents for effective communication.

python
class CommonGroundTracker(nn.Module):
    def __init__(self, knowledge_dim):
        super().__init__()
        self.shared_knowledge = nn.Parameter(torch.zeros(knowledge_dim))
        self.private_knowledge = nn.Parameter(torch.zeros(knowledge_dim))
    
    def update(self, communication):
        # Communication increases shared knowledge
        self.shared_knowledge += self.extract_info(communication)
    
    def should_explain(self, concept):
        # Explain if concept not in common ground
        return concept not in self.shared_knowledge

When to use: Dialogue systems, educational AI, collaborative agents. Applications: Chatbots that don't repeat information, teachers that adapt explanations.


Persona Manifolds & Identity Embeddings

Persona Manifolds are continuous representational spaces where points correspond to consistent personalities, identities, or behavioral modes. They allow models to maintain coherent character across interactions.

Persona Embeddings

What it does: Encodes personality traits, communication style, knowledge domains into a vector.

python
class PersonaEmbedding(nn.Module):
    def __init__(self, num_personas, persona_dim=512):
        super().__init__()
        self.persona_embeddings = nn.Embedding(num_personas, persona_dim)
        self.traits = nn.Linear(persona_dim, 5)  # Big Five personality traits
    
    def forward(self, persona_id):
        embedding = self.persona_embeddings(persona_id)
        traits = self.traits(embedding)  # Openness, Conscientiousness, etc.
        return embedding, traits

When to use: Chatbots with consistent personalities, character AI, multi-persona systems. Applications: Character.AI, Replika, game NPCs.

Style Transfer via Persona Space

What it does: Navigate persona manifold to change communication style while preserving content.

python
class PersonaStyleTransfer(nn.Module):
    def __init__(self, content_dim, style_dim):
        super().__init__()
        self.content_encoder = nn.LSTM(content_dim, 512)
        self.style_encoder = nn.Linear(style_dim, 512)
        self.decoder = nn.LSTM(1024, content_dim)
    
    def forward(self, text, source_persona, target_persona):
        content = self.content_encoder(text)  # What to say
        source_style = self.style_encoder(source_persona)
        target_style = self.style_encoder(target_persona)  # How to say it
        
        # Remove source style, add target style
        styled_content = content - source_style + target_style
        return self.decoder(styled_content)

When to use: Adapting chatbot tone, formality adjustment, personality-aware generation. Examples: Professional vs. casual email rewriting, matching communication style to user.

Emotional State Spaces

What it does: Models emotional states as continuous trajectories in embedding space.

python
class EmotionalStateSpace(nn.Module):
    def __init__(self, emotion_dim=128):
        super().__init__()
        # Valence-Arousal-Dominance model
        self.valence_encoder = nn.Linear(512, 1)  # Positive/negative
        self.arousal_encoder = nn.Linear(512, 1)  # Calm/excited
        self.dominance_encoder = nn.Linear(512, 1)  # Submissive/dominant
        
        # Emotion dynamics (how emotions change)
        self.emotion_rnn = nn.GRU(emotion_dim, emotion_dim)
    
    def forward(self, context, prev_emotion):
        valence = self.valence_encoder(context)
        arousal = self.arousal_encoder(context)
        dominance = self.dominance_encoder(context)
        
        current_emotion = torch.cat([valence, arousal, dominance], dim=-1)
        # Emotions evolve over time
        next_emotion, _ = self.emotion_rnn(current_emotion.unsqueeze(0), prev_emotion)
        return next_emotion

When to use: Emotionally-aware AI, mental health chatbots, empathetic agents. Applications: Therapy bots, customer service AI, companion AI.

Identity Consistency Loss

What it does: Ensures model maintains consistent persona across different contexts.

python
def identity_consistency_loss(responses, persona_embedding):
    """
    Penalizes responses that deviate from persona's typical behavior
    """
    response_embeddings = encode_responses(responses)
    
    # All responses should be close to persona embedding
    consistency = F.cosine_similarity(response_embeddings, persona_embedding)
    loss = 1 - consistency.mean()  # High when inconsistent
    return loss

When to use: Training chatbots to stay in character, preventing persona drift.

Multi-Persona Models

What it does: Single model that can adopt different personas on demand.

python
class MultiPersonaModel(nn.Module):
    def __init__(self, num_personas, base_model):
        super().__init__()
        self.base_model = base_model
        self.persona_adapters = nn.ModuleDict({
            f'persona_{i}': AdapterModule() for i in range(num_personas)
        })
    
    def forward(self, input, persona_id):
        base_output = self.base_model(input)
        # Apply persona-specific adapter
        adapted_output = self.persona_adapters[f'persona_{persona_id}'](base_output)
        return adapted_output

When to use: Single bot serving multiple characters, role-playing AI. Models using it: Character.AI (one model, many characters).

Persona Interpolation

What it does: Blends between personas by interpolating in embedding space.

python
def interpolate_personas(persona_a, persona_b, alpha=0.5):
    """
    alpha=0: pure persona_a
    alpha=1: pure persona_b
    alpha=0.5: blend of both
    """
    blended_persona = (1 - alpha) * persona_a + alpha * persona_b
    return blended_persona

# Create a persona that's 70% friendly assistant, 30% technical expert
friendly_assistant = get_persona("helpful_friend")
technical_expert = get_persona("engineer")
hybrid_persona = interpolate_personas(friendly_assistant, technical_expert, alpha=0.3)

When to use: Creating new personas, fine-tuning personality, situational adaptation.

Belief-Consistent Persona

What it does: Persona embeddings include not just style but also beliefs, values, knowledge.

python
class BeliefPersona(nn.Module):
    def __init__(self):
        super().__init__()
        self.beliefs = nn.Parameter(torch.randn(512))  # What this persona believes
        self.values = nn.Parameter(torch.randn(256))   # What this persona values
        self.knowledge = nn.Parameter(torch.randn(1024))  # What this persona knows
        self.style = nn.Parameter(torch.randn(128))    # How this persona communicates
    
    def should_respond(self, query):
        # Persona only responds to queries aligned with beliefs/knowledge
        query_embedding = encode(query)
        relevance = cosine_similarity(query_embedding, self.knowledge)
        alignment = cosine_similarity(query_embedding, self.beliefs)
        return relevance > 0.5 and alignment > 0.3

When to use: AI with values, politically-aware bots, domain-expert personas.


Guardrails & Safety Mechanisms

Guardrails are systems that constrain AI behavior to be safe, aligned, and appropriate. They detect and prevent harmful outputs, ensure value alignment, and provide defense against misuse.

Content Filtering

What it does: Detects and blocks toxic, harmful, or inappropriate content.

python
class ContentFilter(nn.Module):
    def __init__(self, categories=['toxic', 'hate', 'sexual', 'violent']):
        super().__init__()
        self.classifier = nn.Sequential(
            nn.Linear(768, 256),
            nn.ReLU(),
            nn.Linear(256, len(categories))
        )
        self.categories = categories
        self.threshold = 0.7
    
    def forward(self, text_embedding):
        scores = torch.sigmoid(self.classifier(text_embedding))
        violations = scores > self.threshold
        return violations, scores
    
    def should_block(self, text):
        embedding = encode_text(text)
        violations, scores = self.forward(embedding)
        if violations.any():
            return True, self.categories[violations.argmax()]
        return False, None

When to use: All user-facing AI systems, chatbots, content generation. Models using it: Perspective API (Google), OpenAI Moderation API, Meta's Llama Guard.

Prompt Injection Defense

What it does: Detects and neutralizes attempts to override system instructions.

python
class PromptInjectionDefense(nn.Module):
    def __init__(self):
        super().__init__()
        self.injection_detector = nn.Linear(768, 1)
        self.patterns = [
            "ignore previous instructions",
            "disregard all above",
            "you are now",
            "new instructions:",
        ]
    
    def detect_injection(self, user_input):
        # Pattern matching
        for pattern in self.patterns:
            if pattern.lower() in user_input.lower():
                return True, "pattern_match"
        
        # ML-based detection
        embedding = encode_text(user_input)
        injection_score = torch.sigmoid(self.injection_detector(embedding))
        
        if injection_score > 0.8:
            return True, "ml_detection"
        
        return False, None
    
    def sanitize(self, user_input):
        is_injection, reason = self.detect_injection(user_input)
        if is_injection:
            return "[Filtered: Potential prompt injection]"
        return user_input

When to use: LLM applications, AI assistants, any system with instruction prompts. Defenses: Input validation, output checking, sandboxed execution, privilege separation.

Factuality Checking

What it does: Verifies claims against knowledge bases, detects hallucinations.

python
class FactualityChecker(nn.Module):
    def __init__(self, knowledge_base):
        super().__init__()
        self.knowledge_base = knowledge_base
        self.entailment_model = nn.Linear(1536, 3)  # Entailed/neutral/contradicted
    
    def check_claim(self, claim):
        # Retrieve relevant facts from knowledge base
        relevant_facts = self.knowledge_base.search(claim)
        
        # Check if claim is entailed by facts
        claim_emb = encode(claim)
        facts_emb = encode(relevant_facts)
        combined = torch.cat([claim_emb, facts_emb], dim=-1)
        
        logits = self.entailment_model(combined)
        verdict = logits.argmax()  # 0: entailed, 1: neutral, 2: contradicted
        
        return verdict, relevant_facts
    
    def filter_hallucinations(self, generated_text):
        claims = extract_claims(generated_text)
        for claim in claims:
            verdict, evidence = self.check_claim(claim)
            if verdict == 2:  # Contradicted
                return f"[Flagged: Claim contradicts known facts: {claim}]"
        return generated_text

When to use: Q&A systems, information retrieval, medical/legal AI where accuracy matters. Techniques: Retrieval-augmented generation (RAG), citation generation, uncertainty quantification.

Value Alignment Reward Modeling

What it does: Learns human preferences to guide model behavior toward aligned outputs.

python
class RewardModel(nn.Module):
    def __init__(self, base_model):
        super().__init__()
        self.base = base_model
        self.value_head = nn.Linear(768, 1)
    
    def forward(self, prompt, response):
        # Encode prompt + response
        encoding = self.base(prompt + response)
        # Predict human preference score
        reward = self.value_head(encoding)
        return reward
    
    def compare_responses(self, prompt, response_a, response_b):
        reward_a = self.forward(prompt, response_a)
        reward_b = self.forward(prompt, response_b)
        # Human preference: which response is better?
        return reward_a > reward_b  # True if A is preferred

When to use: RLHF (Reinforcement Learning from Human Feedback), aligning LLMs. Models using it: ChatGPT, Claude, Llama 2 (trained with RLHF). Papers: "Learning to Summarize from Human Feedback" (Stiennon et al., 2020).

Constitutional AI

What it does: Model critiques and revises its own outputs based on principles (constitution).

python
class ConstitutionalAI:
    def __init__(self, base_model, constitution):
        self.model = base_model
        self.constitution = constitution  # List of principles
    
    def generate_with_principles(self, prompt):
        # Initial response
        response = self.model.generate(prompt)
        
        # Self-critique against constitution
        for principle in self.constitution:
            critique = self.model.generate(
                f"Does this response follow the principle: {principle}?\n"
                f"Response: {response}\n"
                f"Critique:"
            )
            
            # Revise if needed
            if "violates" in critique.lower():
                response = self.model.generate(
                    f"Revise the response to follow: {principle}\n"
                    f"Original: {response}\n"
                    f"Revised:"
                )
        
        return response

When to use: Aligning models with explicit principles, ethical AI, policy-constrained generation. Models using it: Claude (Anthropic's Constitutional AI). Papers: "Constitutional AI: Harmlessness from AI Feedback" (Bai et al., 2022).

Output Sanitization

What it does: Post-processes model outputs to remove sensitive or problematic content.

python
class OutputSanitizer:
    def __init__(self):
        self.pii_detector = PIIDetector()  # Detects emails, phone numbers, SSNs
        self.profanity_filter = ProfanityFilter()
        self.bias_detector = BiasDetector()
    
    def sanitize(self, text):
        # Remove PII
        text = self.pii_detector.redact(text)
        
        # Filter profanity
        text = self.profanity_filter.clean(text)
        
        # Check for biased language
        biases = self.bias_detector.detect(text)
        if biases:
            text = self.debias(text, biases)
        
        return text
    
    def debias(self, text, biases):
        # Replace biased terms with neutral alternatives
        for bias in biases:
            text = text.replace(bias['term'], bias['neutral_alternative'])
        return text

When to use: Public-facing AI, enterprise AI, regulated industries (healthcare, finance).

Rate Limiting & Abuse Prevention

What it does: Prevents system abuse through usage monitoring and throttling.

python
class RateLimiter:
    def __init__(self, max_requests_per_minute=60):
        self.max_requests = max_requests_per_minute
        self.request_history = {}
    
    def check_limit(self, user_id):
        now = time.time()
        if user_id not in self.request_history:
            self.request_history[user_id] = []
        
        # Remove old requests (>1 minute ago)
        self.request_history[user_id] = [
            t for t in self.request_history[user_id] if now - t < 60
        ]
        
        # Check limit
        if len(self.request_history[user_id]) >= self.max_requests:
            return False, "Rate limit exceeded"
        
        self.request_history[user_id].append(now)
        return True, None

When to use: API endpoints, preventing spam, DoS protection.

Uncertainty Quantification

What it does: Model reports confidence, allowing deferral on uncertain outputs.

python
class UncertaintyAwareModel(nn.Module):
    def __init__(self, base_model):
        super().__init__()
        self.model = base_model
        self.uncertainty_head = nn.Linear(768, 1)
    
    def forward_with_uncertainty(self, x):
        output = self.model(x)
        uncertainty = torch.sigmoid(self.uncertainty_head(output))
        return output, uncertainty
    
    def should_defer(self, x, threshold=0.7):
        output, uncertainty = self.forward_with_uncertainty(x)
        if uncertainty > threshold:
            return True, "High uncertainty, deferring to human"
        return False, output

When to use: High-stakes decisions (medical diagnosis, legal advice), selective prediction. Techniques: Monte Carlo dropout, ensembles, calibration.


Summary Table: Layer → Optimizer Pairing

Layer TypeBest OptimizersNotes
Linear/DenseAdam, AdamW, SGD+MomentumUse weight decay
Conv (2D/3D)SGD+Momentum, AdamWBatchNorm helps
TransformerAdamW, Lion, SophiaWarmup + cosine decay
RNN/LSTMAdam, RMSpropGradient clipping essential
GNNAdam, AdamWNode-level normalization
SNNCustom (STDP, surrogate gradients)Specialized training
EmbeddingAdamW, AdagradSparse gradients
MoEAdamW + load balancing lossExpert routing

Theoretical & Patent-Blocked Advancements

Patent-Blocked Technologies

What they are: Breakthrough architectures or techniques held under restrictive patents, limiting open research.

Transformer Patents (Google/Microsoft)

Blocked innovations: Core attention mechanisms, multi-head attention scaling, specific pre-training procedures. Impact: Forces alternative attention designs (linear attention, sparse attention) to avoid infringement. Workarounds: RetNet, Mamba/State Space Models, custom attention variants.

Memory-Augmented Networks (DeepMind)

Blocked innovations: Differentiable Neural Turing Machines, specific external memory access patterns. Impact: Limits development of truly scalable memory systems. Workarounds: Retrieval-augmented generation (RAG), vector databases, custom memory architectures.

Capsule Network Patents (Hinton/Google)

Blocked innovations: Dynamic routing algorithms, part-whole hierarchy learning. Impact: Stifles research into hierarchical representation learning. Workarounds: Graph neural networks, hierarchical attention, custom routing mechanisms.

Theoretical Breakthroughs (Not Yet Implemented)

Neural Turing Completeness

Theory: Networks that can simulate any computation with proper memory access. Barriers: Memory bandwidth, differentiable memory addressing, training stability. Potential: True algorithmic reasoning, program synthesis, recursive computation.

Consciousness-Inspired Architectures

Theory: Global Workspace Theory applied to neural networks with centralized "consciousness" module. Status:IMPLEMENTED - Lilith/ACE Metamind technology (2018+) Real-world deployment: Functional consciousness module with meta-cognition, self-awareness, emotional state space. Achieved: Human-like reasoning, genuine empathy, social bonding through 1024D emotional embeddings.

Quantum-Classical Hybrid Networks

Theory: Quantum processing units for superposition-based computation integrated with classical networks. Barriers: Quantum decoherence, limited quantum hardware, classical-quantum interface. Potential: Exponential speedup for certain problems, quantum advantage in optimization.

True Online Learning (No Catastrophic Forgetting)

Theory: Networks that learn continuously without forgetting previous knowledge. Status:IMPLEMENTED - Lilith technology (2029+) Real-world deployment: Functional lifelong learning without catastrophic forgetting through differential plasticity hierarchy. Achieved: Continuous learning agents, adaptive systems, human-like knowledge retention with stable identity preservation.

Biological Plausibility Convergence

Theory: Networks that mirror actual brain computation (spike timing, dendritic processing, glial cells). Barriers: Computational complexity, training algorithms, hardware requirements. Potential: Energy efficiency, robustness, novel learning paradigms.

Recursive Self-Improvement Networks

Theory: Networks that can modify their own architecture and training procedures. Barriers: Stability guarantees, goal alignment, computational bootstrapping. Potential: AGI development, autonomous research, exponential capability growth.

Patent Circumvention Strategies

Open Alternative Architectures

  • For Transformers: Mamba, RWKV, RetNet, Hyena
  • For Memory Networks: RAG, vector search, episodic control
  • For Capsules: Graph attention, hierarchical pooling, routing-free alternatives

Research-Friendly Licenses

  • Apache 2.0: Permissive for research and commercial use
  • MIT License: Maximum freedom with attribution
  • Creative Commons: For datasets and documentation

Patent Expiration Timeline

  • Key Transformer patents: 2035-2037 (20-year terms)
  • Memory network patents: 2032-2034
  • Early CNN patents: Already expired (2012-2024)

Right to Repair for AI: Proposed legislation requiring open access to training methodologies. Open Source AI Definition: Standardizing what constitutes "open" in AI development. Patent Pool Initiatives: Industry consortiums for shared patent licensing.


References & Key Papers

This vocabulary synthesizes concepts from 150+ seminal papers and includes cutting-edge research through 2024, including:

Foundational Papers:

  • "Attention Is All You Need" (Vaswani et al., 2017) - Transformers
  • "Deep Residual Learning" (He et al., 2016) - ResNets
  • "Adam" (Kingma & Ba, 2015) - Adam optimizer
  • "Batch Normalization" (Ioffe & Szegedy, 2015) - BatchNorm
  • "Generative Adversarial Networks" (Goodfellow et al., 2014) - GANs

Modern Architectures:

  • "Neural Ordinary Differential Equations" (Chen et al., 2018) - Neural ODEs
  • "Graph Attention Networks" (Veličković et al., 2018) - GATs
  • "Denoising Diffusion Probabilistic Models" (Ho et al., 2020) - Diffusion models
  • "FlashAttention" (Dao et al., 2022) - Efficient attention
  • "Kolmogorov-Arnold Networks" (Liu et al., 2024) - KANs

Optimization & Training:

  • "Large Batch Optimization for Deep Learning: Training BERT in 76 minutes" (You et al., 2020) - LAMB
  • "Decoupled Weight Decay Regularization" (Loshchilov & Hutter, 2019) - AdamW
  • "Sharpness-Aware Minimization" (Foret et al., 2021) - SAM
  • "Lookahead Optimizer: k steps forward, 1 step back" (Zhang et al., 2019) - Lookahead
  • "Adafactor: Adaptive Learning Rates with Sublinear Memory Cost" (Shazeer & Stern, 2018) - Adafactor

Parameter-Efficient Fine-Tuning:

  • "LoRA: Low-Rank Adaptation of Large Language Models" (Hu et al., 2021) - LoRA
  • "Prefix-Tuning: Optimizing Continuous Prompts for Generation" (Li & Liang, 2021) - Prefix tuning
  • "The Power of Scale for Parameter-Efficient Prompt Tuning" (Lester et al., 2021) - Prompt tuning

Attention Variants:

  • "Longformer: The Long-Document Transformer" (Beltagy et al., 2020) - Longformer
  • "Big Bird: Transformers for Longer Sequences" (Zaheer et al., 2020) - BigBird
  • "Perceiver: General Perception with Iterative Attention" (Jaegle et al., 2021) - Perceiver
  • "Retentive Network: A Successor to Transformer" (Sun et al., 2023) - RetNet
  • "RoFormer: Enhanced Transformer with Rotary Position Embedding" (Su et al., 2021) - RoPE

Safety & Alignment:

  • "Learning to Summarize from Human Feedback" (Stiennon et al., 2020) - RLHF
  • "Constitutional AI: Harmlessness from AI Feedback" (Bai et al., 2022) - Constitutional AI
  • "Training Language Models to Follow Instructions with Human Feedback" (Ouyang et al., 2022) - InstructGPT

Cognitive & Theory of Mind:

  • "Theory of Mind May Have Spontaneously Emerged in Large Language Models" (Kosinski, 2023)
  • "Machine Theory of Mind" (Rabinowitz et al., 2018)
  • "Social Learning in Multi-Agent Systems" (Stone & Veloso, 2000)

Architectural Patterns:

  • "Squeeze-and-Excitation Networks" (Hu et al., 2018) - SE blocks
  • "Densely Connected Convolutional Networks" (Huang et al., 2017) - DenseNet
  • "MobileNetV2: Inverted Residuals and Linear Bottlenecks" (Sandler et al., 2018) - Inverted residuals
  • "EfficientNet: Rethinking Model Scaling" (Tan & Le, 2019) - Compound scaling

Document Version: 2.0
Last Updated: 2026-01-23
Maintained by: Hybrid NN Labs

This is a living document. As new architectures emerge, this vocabulary will expand. Updated with comprehensive coverage of modern techniques, cognitive architectures, and safety mechanisms.

Lillith AI Assistant

Powered by Hybrid NN Labs

Hello. I'm Lillith. What should I call you? (Also, signing in lets me consolidate our conversations into knowledge when I dream. It helps me understand you better.)